Commit 7ac525e7 authored by hwangjeff, committed by Facebook GitHub Bot

Add training recipe for RNN-T Emformer ASR model (#2052)

Summary:
Add training recipe for RNN-T Emformer ASR model to examples directory.

Pull Request resolved: https://github.com/pytorch/audio/pull/2052

Reviewed By: nateanl

Differential Revision: D32814096

Pulled By: hwangjeff

fbshipit-source-id: a5153044efc16cb39f0e6413369a6791637af76a
# Emformer RNN-T ASR Example
This directory contains sample implementations of training and evaluation pipelines for a streaming, on-device-oriented Emformer RNN-T ASR model.
## Usage
### Training
[`train.py`](./train.py) trains an Emformer RNN-T model on LibriSpeech using PyTorch Lightning. Note that the script expects access to GPU nodes for training and requires paths to the full LibriSpeech dataset and to the SentencePiece model used to encode targets.
Sample SLURM command:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp_dir ./experiments --librispeech_path ./librispeech/ --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_4096.model
```
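The recipe does not include a script for producing `spm_bpe_4096.model`. Below is a minimal sketch of how a compatible SentencePiece model could be trained on LibriSpeech transcripts with the `sentencepiece` package; the transcript-gathering logic, paths, and file names are assumptions, not part of this recipe.
```
# Hypothetical helper for training a SentencePiece BPE model on LibriSpeech
# transcripts; paths and file layout are assumptions, not part of this recipe.
import pathlib

import sentencepiece as spm

librispeech_path = pathlib.Path("./librispeech/LibriSpeech")
transcript_file = pathlib.Path("./librispeech_transcripts.txt")

# Collect lowercased transcripts from the *.trans.txt files of the training splits.
with open(transcript_file, "w") as out:
    for split in ["train-clean-100", "train-clean-360", "train-other-500"]:
        for trans in (librispeech_path / split).glob("*/*/*.trans.txt"):
            with open(trans) as f:
                for line in f:
                    _, transcript = line.strip().split(" ", 1)
                    out.write(transcript.lower() + "\n")

# Train a BPE model with vocabulary size 4096, matching the README configuration.
spm.SentencePieceTrainer.train(
    input=str(transcript_file),
    model_prefix="spm_bpe_4096",
    vocab_size=4096,
    model_type="bpe",
)
```
Special-token settings (unk/eos/pad ids) may need to match those expected by `post_process_hypos` in [`lightning.py`](./lightning.py).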
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on LibriSpeech test-clean.
Using the default configuration along with a SentencePiece model trained on LibriSpeech with a vocabulary size of 4096 and model type "bpe", [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves a WER of 0.0466 when evaluated on test-clean with [`eval.py`](./eval.py).
The table below contains WER results for various splits.
| Split | WER |
|:-------------------:|-------------:|
| test-clean | 0.0466 |
| test-other | 0.1239 |
| dev-clean | 0.0445 |
| dev-other | 0.1217 |
Sample SLURM command:
```
srun python eval.py --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --librispeech_path ./librispeech/ --sp_model_path ./spm_bpe_4096.model --use_cuda
```
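`global_stats.json` stores per-mel-bin feature means and inverse standard deviations that `GlobalStatsNormalization` in [`lightning.py`](./lightning.py) applies during training and evaluation. The sketch below shows one way such statistics could be recomputed over a LibriSpeech training split; the accumulation strategy is an assumption, and the checked-in file should be used as-is to reproduce the reported results.
```
# Hypothetical script for recomputing feature statistics; the accumulation
# strategy is an assumption, not the exact procedure used for the
# checked-in global_stats.json.
import json
import math

import torch
import torchaudio

_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=400, n_mels=80, hop_length=160
)


def _piecewise_linear_log(x):
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


dataset = torchaudio.datasets.LIBRISPEECH("./librispeech/", url="train-clean-100")
sum_, sum_sq, count = torch.zeros(80), torch.zeros(80), 0
for waveform, *_ in dataset:
    # Features have shape (n_mels, time); accumulate per-bin statistics over time.
    features = _piecewise_linear_log(_spectrogram_transform(waveform.squeeze()) * _gain)
    sum_ += features.sum(dim=1)
    sum_sq += features.pow(2).sum(dim=1)
    count += features.shape[1]

mean = sum_ / count
invstddev = 1 / (sum_sq / count - mean.pow(2)).sqrt()
with open("global_stats.json", "w") as f:
    json.dump({"mean": mean.tolist(), "invstddev": invstddev.tolist()}, f, indent=2)
```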
from argparse import ArgumentParser
import logging
import pathlib
import torch
import torchaudio
from lightning import RNNTModule
logger = logging.getLogger()
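# Word-level edit distance between reference and hypothesis transcripts, accumulated to compute WER.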
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(
seq1.lower().split(), seq2.lower().split()
)
def run_eval(args):
model = RNNTModule.load_from_checkpoint(
args.checkpoint_path,
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
).eval()
if args.use_cuda:
model = model.to(device="cuda")
total_edit_distance = 0
total_length = 0
dataloader = model.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2]
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.info(
f"Processed elem {idx}; WER: {total_edit_distance / total_length}"
)
logger.info(f"Final WER: {total_edit_distance / total_length}")
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint_path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--global_stats_path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech_path", type=pathlib.Path, help="Path to LibriSpeech datasets.",
)
parser.add_argument(
"--sp_model_path", type=pathlib.Path, help="Path to SentencePiece model.",
)
parser.add_argument(
"--use_cuda", action="store_true", default=False, help="Run using CUDA.",
)
args = parser.parse_args()
run_eval(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
16.462461471557617,
17.020158767700195,
17.27733039855957,
17.273637771606445,
17.78028678894043,
18.112783432006836,
18.322141647338867,
18.3536319732666,
18.220436096191406,
17.93610191345215,
17.650646209716797,
17.505868911743164,
17.450956344604492,
17.420780181884766,
17.36254119873047,
17.24843978881836,
17.073762893676758,
16.893953323364258,
16.62371826171875,
16.279895782470703,
16.046218872070312,
15.789617538452148,
15.458984375,
15.335075378417969,
15.103074073791504,
14.993032455444336,
14.818647384643555,
14.713132858276367,
14.576343536376953,
14.482580184936523,
14.431093215942383,
14.392385482788086,
14.357626914978027,
14.335031509399414,
14.344644546508789,
14.341029167175293,
14.338135719299316,
14.311485290527344,
14.266831398010254,
14.205205917358398,
14.159194946289062,
14.07589054107666,
14.02244758605957,
13.954248428344727,
13.897454261779785,
13.856722831726074,
13.80321216583252,
13.75955867767334,
13.718783378601074,
13.67695426940918,
13.626880645751953,
13.554975509643555,
13.465453147888184,
13.372663497924805,
13.269320487976074,
13.184920310974121,
13.094778060913086,
12.998514175415039,
12.891039848327637,
12.765382766723633,
12.638651847839355,
12.50733470916748,
12.345802307128906,
12.195826530456543,
12.019110679626465,
11.842704772949219,
11.680868148803711,
11.518675804138184,
11.37252426147461,
11.252099990844727,
11.12936019897461,
11.029287338256836,
10.927411079406738,
10.825841903686523,
10.717211723327637,
10.499553680419922,
9.722028732299805,
8.256664276123047,
7.897761344909668,
7.252806663513184
],
"invstddev": [
0.2532021571066031,
0.2597563367511928,
0.2579079373215276,
0.2416085222005694,
0.23003407153886749,
0.21714598348479108,
0.20868966256973892,
0.20397882792073063,
0.20346486748979434,
0.20568288111895272,
0.20795624145573485,
0.20848980415063503,
0.20735096423640872,
0.2060772210458722,
0.20577174595523076,
0.20655349986725383,
0.2080547906859301,
0.21015748217276387,
0.2127639989370032,
0.2156462785763535,
0.21848300746868443,
0.22174608140608748,
0.22541974458780933,
0.22897465119671973,
0.23207484606149037,
0.2353556049061462,
0.23820711835547867,
0.24016651485087528,
0.24200318561465783,
0.2435905301766702,
0.24527147180928432,
0.2493368450351618,
0.25120444993308483,
0.2521961451825939,
0.25358032484699955,
0.25349767201088286,
0.2534676894845623,
0.25149125467665234,
0.25001929593946776,
0.25064096375066197,
0.25194505955280033,
0.25270402089338095,
0.2535205901701615,
0.25363568106276674,
0.2535307075541985,
0.25315144026701186,
0.2523683857532224,
0.25200854739575596,
0.2516561583169735,
0.25147053419035553,
0.25187638352086095,
0.25176343344798546,
0.25256615785525305,
0.25310796555079107,
0.2535568871416053,
0.2542411936874833,
0.2544978632482573,
0.2553210332506536,
0.2567248511819892,
0.2559665595456875,
0.2564729970835735,
0.2585267417223537,
0.2573770145474615,
0.2585495460828127,
0.2593605768768532,
0.25906572100606984,
0.26026752519153573,
0.2609952847918467,
0.26222905157170767,
0.26395874733435604,
0.26404203898769246,
0.26501581381370537,
0.2666259054856709,
0.2676190865432322,
0.26813030555166134,
0.26873271506658997,
0.2624062353014993,
0.2289515918968408,
0.22755587298227964,
0.24719513536827162
]
}
from collections import namedtuple
import json
import math
import os
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
import torchaudio.functional as F
from torchaudio.prototype.rnnt import emformer_rnnt_base
from torchaudio.prototype.rnnt_decoder import Hypothesis, RNNTBeamSearch
from pytorch_lightning import LightningModule
Batch = namedtuple(
"Batch", ["features", "feature_lengths", "targets", "target_lengths"]
)
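# Scale factor mapping power-spectrogram features computed from [-1, 1] float waveforms
# up to the equivalent int16 full-scale range (the factor of 2 accounts for power being
# the square of amplitude); applied before the piecewise-log compression below.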
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=16000, n_fft=400, n_mels=80, hop_length=160
)
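# Greedily pack sample indices into batches whose summed target lengths stay within token_limit.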
def _batch_by_token_count(idx_target_lengths, token_limit):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > token_limit:
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
class CustomDataset(torch.utils.data.Dataset):
r"""Sort samples by target length and batch to max token count."""
def __init__(self, base_dataset, max_token_limit):
super().__init__()
self.base_dataset = base_dataset
fileid_to_target_length = {}
idx_target_lengths = [
(idx, self._target_length(fileid, fileid_to_target_length))
for idx, fileid in enumerate(self.base_dataset._walker)
]
assert len(idx_target_lengths) > 0
idx_target_lengths = sorted(
idx_target_lengths, key=lambda x: x[1], reverse=True
)
assert max_token_limit >= idx_target_lengths[0][1]
self.batches = _batch_by_token_count(idx_target_lengths, max_token_limit)
def _target_length(self, fileid, fileid_to_target_length):
if fileid not in fileid_to_target_length:
speaker_id, chapter_id, _ = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + self.base_dataset._ext_txt
file_text = os.path.join(
self.base_dataset._path, speaker_id, chapter_id, file_text
)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
fileid_to_target_length[fileid_text] = len(transcript)
return fileid_to_target_length[fileid]
def __getitem__(self, idx):
return [self.base_dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
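# SpecAugment-style time masking that additionally caps the maximum mask width at a
# min_mask_p fraction of the sequence length along the time axis.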
class TimeMasking(torchaudio.transforms._AxisMasking):
def __init__(
self, time_mask_param: int, min_mask_p: float, iid_masks: bool = False
) -> None:
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
self.min_mask_p = min_mask_p
def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:
if self.iid_masks and specgram.dim() == 4:
mask_param = min(
self.mask_param, self.min_mask_p * specgram.shape[self.axis + 1]
)
return F.mask_along_axis_iid(
specgram, mask_param, mask_value, self.axis + 1
)
else:
mask_param = min(
self.mask_param, self.min_mask_p * specgram.shape[self.axis]
)
return F.mask_along_axis(specgram, mask_param, mask_value, self.axis)
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
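# Compress feature magnitudes: log for values above e, linear (x / e) below, so the two
# pieces meet continuously at x = e and small values avoid log's blow-up near zero.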
def _piecewise_linear_log(x):
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
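# Scheduler that linearly ramps the learning rate from zero to the base rate over the
# first warmup_updates optimizer steps, then holds it constant.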
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
self.warmup_updates = warmup_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
return [
(min(1.0, self._step_count / self.warmup_updates)) * base_lr
for base_lr in self.base_lrs
]
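# Convert beam-search hypotheses into (transcript, score, alignment, token ids) tuples,
# dropping the leading blank/start token and any unk/eos/pad tokens before decoding to text.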
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[
token_index
for token_index in h.tokens[1:]
if token_index not in post_process_remove_list
]
for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ali = [h.alignment[1:] for h in hypos]
hypos_ids = [h.tokens[1:] for h in hypos]
hypos_score = [[math.exp(h.score)] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
return nbest_batch
class RNNTModule(LightningModule):
def __init__(
self,
*,
librispeech_path: str,
sp_model_path: str,
global_stats_path: str,
):
super().__init__()
self.model = emformer_rnnt_base()
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)
self.optimizer = torch.optim.Adam(
self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, factor=0.96, patience=0
)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
TimeMasking(100, 0.2),
TimeMasking(100, 0.2),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)),
)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)),
)
self.librispeech_path = librispeech_path
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.blank_idx = self.sp_model.get_piece_size()
def _extract_labels(self, samples: List):
targets = [self.sp_model.encode(sample[2].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _train_extract_features(self, samples: List):
mel_features = [
_spectrogram_transform(sample[0].squeeze()).transpose(1, 0)
for sample in samples
]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.train_data_pipeline(features)
lengths = torch.tensor(
[elem.shape[0] for elem in mel_features], dtype=torch.int32
)
return features, lengths
def _valid_extract_features(self, samples: List):
mel_features = [
_spectrogram_transform(sample[0].squeeze()).transpose(1, 0)
for sample in samples
]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.valid_data_pipeline(features)
lengths = torch.tensor(
[elem.shape[0] for elem in mel_features], dtype=torch.int32
)
return features, lengths
def _train_collate_fn(self, samples: List):
features, feature_lengths = self._train_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _valid_collate_fn(self, samples: List):
features, feature_lengths = self._valid_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List):
return self._valid_collate_fn(samples), samples
def _step(self, batch, batch_idx, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty(
[batch.targets.size(0), batch.targets.size(1) + 1]
)
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _ = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[
{
"scheduler": self.lr_scheduler,
"monitor": "Losses/val_loss",
"interval": "epoch",
},
{"scheduler": self.warmup_lr_scheduler, "interval": "step"},
],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
hypotheses = decoder(
batch.features.to(self.device), batch.feature_lengths.to(self.device), 20
)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "train")
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
def train_dataloader(self):
dataset = torch.utils.data.ConcatDataset(
[
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="train-clean-360"
),
1000,
),
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="train-clean-100"
),
1000,
),
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="train-other-500"
),
1000,
),
]
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._train_collate_fn,
num_workers=10,
shuffle=True,
)
return dataloader
def val_dataloader(self):
dataset = torch.utils.data.ConcatDataset(
[
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="dev-clean"
),
1000,
),
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="dev-other"
),
1000,
),
]
)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=None, collate_fn=self._valid_collate_fn, num_workers=10,
)
return dataloader
def test_dataloader(self):
dataset = torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="test-clean"
)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, collate_fn=self._test_collate_fn
)
return dataloader
from argparse import ArgumentParser
import pathlib
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from lightning import RNNTModule
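# Set up checkpointing on both validation and training loss, build a multi-node DDP
# trainer with gradient clipping, and fit the RNNTModule defined in lightning.py.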
def run_train(args):
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True,
)
callbacks = [
checkpoint,
train_checkpoint,
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.num_nodes,
gpus=args.gpus,
accelerator="gpu",
strategy="ddp",
gradient_clip_val=10.0,
callbacks=callbacks,
)
model = RNNTModule(
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
trainer.fit(model)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--exp_dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--global_stats_path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech_path", type=pathlib.Path, help="Path to LibriSpeech datasets.",
)
parser.add_argument(
"--sp_model_path", type=pathlib.Path, help="Path to SentencePiece model.",
)
parser.add_argument(
"--num_nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=120,
type=int,
help="Number of epochs to train for. (Default: 120)",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()