import json
import logging
import math
import os
import random
from collections import namedtuple
from typing import List, Tuple

import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule, seed_everything
from torchaudio.models import Hypothesis, RNNTBeamSearch
from torchaudio.prototype.models import conformer_rnnt_base

logger = logging.getLogger()

seed_everything(1)

Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])

_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)

_expected_spm_vocab_size = 1023


def _piecewise_linear_log(x):
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


def _batch_by_token_count(idx_target_lengths, token_limit, sample_limit=None):
    batches = []
    current_batch = []
    current_token_count = 0
    for idx, target_length in idx_target_lengths:
        if current_token_count + target_length > token_limit or (sample_limit and len(current_batch) == sample_limit):
            batches.append(current_batch)
            current_batch = [idx]
            current_token_count = target_length
        else:
            current_batch.append(idx)
            current_token_count += target_length

    if current_batch:
        batches.append(current_batch)

    return batches


def get_sample_lengths(librispeech_dataset):
    fileid_to_target_length = {}

    def _target_length(fileid):
        if fileid not in fileid_to_target_length:
            speaker_id, chapter_id, _ = fileid.split("-")

            file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
            file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text)

            with open(file_text) as ft:
                for line in ft:
                    fileid_text, transcript = line.strip().split(" ", 1)
                    fileid_to_target_length[fileid_text] = len(transcript)

        return fileid_to_target_length[fileid]

    return [_target_length(fileid) for fileid in librispeech_dataset._walker]


class CustomBucketDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False, sample_limit=None):
        super().__init__()

        assert len(dataset) == len(lengths)

        self.dataset = dataset

        max_length = max(lengths)
        min_length = min(lengths)

        assert max_token_limit >= max_length

        buckets = torch.linspace(min_length, max_length, num_buckets)
        lengths = torch.tensor(lengths)
        bucket_assignments = torch.bucketize(lengths, buckets)

        idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
        if shuffle:
            idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
        else:
            idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)

        sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
        self.batches = _batch_by_token_count(
            [(idx, length) for idx, length, _ in sorted_idx_length_buckets],
            max_token_limit,
            sample_limit=sample_limit,
        )

    def __getitem__(self, idx):
        return [self.dataset[subidx] for subidx in self.batches[idx]]

    def __len__(self):
        return len(self.batches)
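
# A small worked example of the batching logic above (illustrative values only):
# with token_limit=10 and idx_target_lengths=[(0, 4), (1, 3), (2, 5), (3, 9)],
# _batch_by_token_count returns [[0, 1], [2], [3]] -- indices 0 and 1 fit inside
# the 10-token budget, index 2 would push the running count to 12 and so starts a
# new batch, and index 3 likewise starts its own batch. CustomBucketDataset first
# groups samples into length buckets so that each such batch contains utterances
# of similar length, which keeps padding overhead low.
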
class FunctionalModule(torch.nn.Module):
    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        return self.functional(input)


class GlobalStatsNormalization(torch.nn.Module):
    def __init__(self, global_stats_path):
        super().__init__()

        with open(global_stats_path) as f:
            blob = json.loads(f.read())

        self.mean = torch.tensor(blob["mean"])
        self.invstddev = torch.tensor(blob["invstddev"])

    def forward(self, input):
        return (input - self.mean) * self.invstddev


class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
    r"""Learning rate scheduler that performs linear warmup and exponential annealing.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to use.
        warmup_steps (int): number of scheduler steps for which to warm up learning rate.
        force_anneal_step (int): scheduler step at which annealing of learning rate begins.
        anneal_factor (float): factor to scale base learning rate by at each annealing step.
        last_epoch (int, optional): The index of last epoch. (Default: -1)
        verbose (bool, optional): If ``True``, prints a message to stdout for
            each update. (Default: ``False``)
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        force_anneal_step: int,
        anneal_factor: float,
        last_epoch=-1,
        verbose=False,
    ):
        self.warmup_steps = warmup_steps
        self.force_anneal_step = force_anneal_step
        self.anneal_factor = anneal_factor
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        if self._step_count < self.force_anneal_step:
            return [(min(1.0, self._step_count / self.warmup_steps)) * base_lr for base_lr in self.base_lrs]
        else:
            scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_step)
            return [scaling_factor * base_lr for base_lr in self.base_lrs]
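
# In other words, with s denoting the scheduler step count:
#   lr(s) = base_lr * min(1, s / warmup_steps)                  for s < force_anneal_step
#   lr(s) = base_lr * anneal_factor ** (s - force_anneal_step)  otherwise
# With the values used below (warmup_steps=40, force_anneal_step=120, anneal_factor=0.96)
# and the scheduler stepped once per epoch, the learning rate ramps up linearly over the
# first 40 epochs, holds at the base rate until epoch 120, and then decays by 4% per epoch.
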
def post_process_hypos(
    hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
    post_process_remove_list = [
        sp_model.unk_id(),
        sp_model.eos_id(),
        sp_model.pad_id(),
    ]
    filtered_hypo_tokens = [
        [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
    ]
    hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
    hypos_ali = [h.alignment[1:] for h in hypos]
    hypos_ids = [h.tokens[1:] for h in hypos]
    hypos_score = [[math.exp(h.score)] for h in hypos]

    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))

    return nbest_batch


class ConformerRNNTModule(LightningModule):
    def __init__(
        self,
        *,
        librispeech_path: str,
        sp_model_path: str,
        global_stats_path: str,
    ):
        super().__init__()

        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        spm_vocab_size = self.sp_model.get_piece_size()
        assert spm_vocab_size == _expected_spm_vocab_size, (
            "The model returned by conformer_rnnt_base expects a SentencePiece model of "
            f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
            f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
        )
        self.blank_idx = spm_vocab_size

        # ``conformer_rnnt_base`` hardcodes a specific Conformer RNN-T configuration.
        # For greater customizability, please refer to ``conformer_rnnt_model``.
        self.model = conformer_rnnt_base()
        self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
        self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96)

        self.train_data_pipeline = torch.nn.Sequential(
            FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(lambda x: x.transpose(1, 2)),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            FunctionalModule(lambda x: x.transpose(1, 2)),
        )
        self.valid_data_pipeline = torch.nn.Sequential(
            FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
            GlobalStatsNormalization(global_stats_path),
        )

        self.librispeech_path = librispeech_path

        self.train_dataset_lengths = None
        self.val_dataset_lengths = None

        self.automatic_optimization = False

    def _extract_labels(self, samples: List):
        targets = [self.sp_model.encode(sample[2].lower()) for sample in samples]
        lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
        targets = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(elem) for elem in targets],
            batch_first=True,
            padding_value=1.0,
        ).to(dtype=torch.int32)
        return targets, lengths

    def _train_extract_features(self, samples: List):
        mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.train_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _valid_extract_features(self, samples: List):
        mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.valid_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _train_collate_fn(self, samples: List):
        features, feature_lengths = self._train_extract_features(samples)
        targets, target_lengths = self._extract_labels(samples)
        return Batch(features, feature_lengths, targets, target_lengths)

    def _valid_collate_fn(self, samples: List):
        features, feature_lengths = self._valid_extract_features(samples)
        targets, target_lengths = self._extract_labels(samples)
        return Batch(features, feature_lengths, targets, target_lengths)

    def _test_collate_fn(self, samples: List):
        return self._valid_collate_fn(samples), samples

    def _step(self, batch, _, step_type):
        if batch is None:
            return None

        # Prepend the blank index to the targets; it serves as the start-of-sequence
        # symbol consumed by the prediction network.
        prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
        prepended_targets[:, 1:] = batch.targets
        prepended_targets[:, 0] = self.blank_idx
        prepended_target_lengths = batch.target_lengths + 1
        output, src_lengths, _, _ = self.model(
            batch.features,
            batch.feature_lengths,
            prepended_targets,
            prepended_target_lengths,
        )
        loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
        self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)

        return loss

    def configure_optimizers(self):
        return (
            [self.optimizer],
            [{"scheduler": self.warmup_lr_scheduler, "interval": "epoch"}],
        )

    def forward(self, batch: Batch):
        # Run beam search with a beam width of 20 and return the top hypothesis string.
        decoder = RNNTBeamSearch(self.model, self.blank_idx)
        hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
        return post_process_hypos(hypotheses, self.sp_model)[0][0]
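
    # Worked example for the loss scaling in ``training_step`` below (illustrative
    # numbers): with N = 4 GPUs and per-GPU batch sizes [3, 2, 4, 3], B_total = 12.
    # Each GPU multiplies its summed RNN-T loss by N / B_total = 4 / 12; after DDP
    # averages the gradients (dividing by N), the effective gradient is the sum of
    # per-sample gradients divided by B_total, i.e. a per-sample average over the
    # whole distributed batch.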
    def training_step(self, batch: Batch, batch_idx):
        """Custom training step.

        By default, DDP does the following on each train step:
        - For each GPU, compute loss and gradient on shard of training data.
        - Sync and average gradients across all GPUs. The final gradient
          is (sum of gradients across all GPUs) / N, where N is the world
          size (total number of GPUs).
        - Update parameters on each GPU.

        Here, we do the following:
        - For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
          the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
        - Sync and average gradients across all GPUs. The final gradient
          is (sum of gradients across all GPUs) / B_total.
        - Update parameters on each GPU.

        Doing so allows us to account for the variability in batch sizes that
        variable-length sequential data commonly yields.
        """
        opt = self.optimizers()
        opt.zero_grad()
        loss = self._step(batch, batch_idx, "train")
        batch_size = batch.features.size(0)
        batch_sizes = self.all_gather(batch_size)
        self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
        loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / batch size
        self.manual_backward(loss)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.0)
        opt.step()

        # step every epoch
        sch = self.lr_schedulers()
        if self.trainer.is_last_batch:
            sch.step()

        return loss

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "test")

    def train_dataloader(self):
        datasets = [
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
        ]

        if not self.train_dataset_lengths:
            self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]

        dataset = torch.utils.data.ConcatDataset(
            [
                CustomBucketDataset(dataset, lengths, 700, 50, shuffle=False, sample_limit=2)
                for dataset, lengths in zip(datasets, self.train_dataset_lengths)
            ]
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            collate_fn=self._train_collate_fn,
            num_workers=10,
            batch_size=None,
            shuffle=True,
        )
        return dataloader

    def val_dataloader(self):
        datasets = [
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
        ]

        if not self.val_dataset_lengths:
            self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]

        dataset = torch.utils.data.ConcatDataset(
            [
                CustomBucketDataset(dataset, lengths, 700, 1, sample_limit=2)
                for dataset, lengths in zip(datasets, self.val_dataset_lengths)
            ]
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self._valid_collate_fn,
            num_workers=10,
        )
        return dataloader

    def test_dataloader(self):
        dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
        return dataloader
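

# A minimal usage sketch (not part of the recipe itself): the paths below are
# placeholders and the Trainer arguments are assumptions chosen for illustration.
#
#   from pytorch_lightning import Trainer
#
#   model = ConformerRNNTModule(
#       librispeech_path="/datasets/librispeech",   # placeholder dataset root
#       sp_model_path="spm_model_file",             # placeholder SentencePiece model
#       global_stats_path="global_stats.json",      # placeholder feature stats file
#   )
#   trainer = Trainer(max_epochs=120, accelerator="gpu", devices=8, strategy="ddp")
#   trainer.fit(model)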