Commit 157cb2a2 authored by hwangjeff, committed by Facebook GitHub Bot

Add global stats script and new json for LibriSpeech RNN-T training recipe (#2183)

Summary:
Adds a script for generating global feature statistics, along with a new feature statistics JSON file, for the LibriSpeech RNN-T training recipe.

Pull Request resolved: https://github.com/pytorch/audio/pull/2183

Reviewed By: mthrok

Differential Revision: D33902377

Pulled By: hwangjeff

fbshipit-source-id: ec347a685ae67aefc485084aac6ed2efd653250f
parent f654b2c9
@@ -6,7 +6,7 @@ This directory contains sample implementations of training and evaluation pipeli
### Training
-[`train.py`](./train.py) trains an Emformer RNN-T model on LibriSpeech using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the full LibriSpeech dataset and the SentencePiece model to be used to encode targets.
+[`train.py`](./train.py) trains an Emformer RNN-T model on LibriSpeech using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the full LibriSpeech dataset and the SentencePiece model to be used to encode targets. The script also expects a file (--global_stats_path) that contains training set feature statistics; this file can be generated via [`global_stats.py`](./global_stats.py).
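For reference, the statistics file can be regenerated with the new script (the dataset path below is a placeholder; both flags come from the script's argument parser):
```
python global_stats.py --librispeech_path ./librispeech --output_path ./global_stats.json
```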
Sample SLURM command:
```
@@ -23,10 +23,10 @@ The table below contains WER results for various splits.
| | WER |
|:-------------------:|-------------:|
-| test-clean | 0.0466 |
-| test-other | 0.1239 |
-| dev-clean | 0.0445 |
-| dev-other | 0.1217 |
+| test-clean | 0.0456 |
+| test-other | 0.1066 |
+| dev-clean | 0.0415 |
+| dev-other | 0.1110 |
Sample SLURM command:
{
"mean": [
-16.462461471557617,
-17.020158767700195,
-17.27733039855957,
-17.273637771606445,
-17.78028678894043,
-18.112783432006836,
-18.322141647338867,
-18.3536319732666,
-18.220436096191406,
-17.93610191345215,
-17.650646209716797,
-17.505868911743164,
-17.450956344604492,
-17.420780181884766,
-17.36254119873047,
-17.24843978881836,
-17.073762893676758,
-16.893953323364258,
-16.62371826171875,
-16.279895782470703,
-16.046218872070312,
-15.789617538452148,
-15.458984375,
-15.335075378417969,
-15.103074073791504,
-14.993032455444336,
-14.818647384643555,
-14.713132858276367,
-14.576343536376953,
-14.482580184936523,
-14.431093215942383,
-14.392385482788086,
-14.357626914978027,
-14.335031509399414,
-14.344644546508789,
-14.341029167175293,
-14.338135719299316,
-14.311485290527344,
-14.266831398010254,
-14.205205917358398,
-14.159194946289062,
-14.07589054107666,
-14.02244758605957,
-13.954248428344727,
-13.897454261779785,
-13.856722831726074,
-13.80321216583252,
-13.75955867767334,
-13.718783378601074,
-13.67695426940918,
-13.626880645751953,
-13.554975509643555,
-13.465453147888184,
-13.372663497924805,
-13.269320487976074,
-13.184920310974121,
-13.094778060913086,
-12.998514175415039,
-12.891039848327637,
-12.765382766723633,
-12.638651847839355,
-12.50733470916748,
-12.345802307128906,
-12.195826530456543,
-12.019110679626465,
-11.842704772949219,
-11.680868148803711,
-11.518675804138184,
-11.37252426147461,
-11.252099990844727,
-11.12936019897461,
-11.029287338256836,
-10.927411079406738,
-10.825841903686523,
-10.717211723327637,
-10.499553680419922,
-9.722028732299805,
-8.256664276123047,
-7.897761344909668,
-7.252806663513184
+15.058613777160645,
+16.34557342529297,
+16.34653663635254,
+16.240671157836914,
+17.45355224609375,
+17.445302963256836,
+17.52323341369629,
+18.076807022094727,
+17.699262619018555,
+17.706790924072266,
+17.24724578857422,
+17.153791427612305,
+17.213361740112305,
+17.347240447998047,
+17.331117630004883,
+17.21516227722168,
+17.030071258544922,
+16.818960189819336,
+16.573062896728516,
+16.29717254638672,
+16.00996971130371,
+15.794167518615723,
+15.616395950317383,
+15.459056854248047,
+15.306838989257812,
+15.199165344238281,
+15.208144187927246,
+14.883454322814941,
+14.787869453430176,
+14.947835922241211,
+14.5912504196167,
+14.76955509185791,
+14.617781639099121,
+14.840407371520996,
+14.83073616027832,
+14.909119606018066,
+14.89070987701416,
+14.918207168579102,
+14.939517974853516,
+14.913643836975098,
+14.863334655761719,
+14.803299903869629,
+14.751264572143555,
+14.688116073608398,
+14.63498306274414,
+14.615056037902832,
+14.680213928222656,
+14.616259574890137,
+14.707776069641113,
+14.630264282226562,
+14.644737243652344,
+14.547430038452148,
+14.529033660888672,
+14.49357795715332,
+14.411538124084473,
+14.33312702178955,
+14.260393142700195,
+14.204919815063477,
+14.130182266235352,
+14.06987476348877,
+14.010197639465332,
+13.938552856445312,
+13.750232696533203,
+13.607213973999023,
+13.457777976989746,
+13.31512451171875,
+13.167718887329102,
+13.019341468811035,
+12.8869047164917,
+12.795098304748535,
+12.685126304626465,
+12.620392799377441,
+12.58949089050293,
+12.537697792053223,
+12.496938705444336,
+12.410022735595703,
+12.346826553344727,
+12.221966743469238,
+12.122841835021973,
+12.005624771118164
],
"invstddev": [
-0.2532021571066031,
-0.2597563367511928,
-0.2579079373215276,
-0.2416085222005694,
-0.23003407153886749,
-0.21714598348479108,
-0.20868966256973892,
-0.20397882792073063,
-0.20346486748979434,
-0.20568288111895272,
-0.20795624145573485,
-0.20848980415063503,
-0.20735096423640872,
-0.2060772210458722,
-0.20577174595523076,
-0.20655349986725383,
-0.2080547906859301,
-0.21015748217276387,
-0.2127639989370032,
-0.2156462785763535,
-0.21848300746868443,
-0.22174608140608748,
-0.22541974458780933,
-0.22897465119671973,
-0.23207484606149037,
-0.2353556049061462,
-0.23820711835547867,
-0.24016651485087528,
-0.24200318561465783,
-0.2435905301766702,
-0.24527147180928432,
-0.2493368450351618,
-0.25120444993308483,
-0.2521961451825939,
-0.25358032484699955,
-0.25349767201088286,
-0.2534676894845623,
-0.25149125467665234,
-0.25001929593946776,
-0.25064096375066197,
-0.25194505955280033,
-0.25270402089338095,
-0.2535205901701615,
-0.25363568106276674,
-0.2535307075541985,
-0.25315144026701186,
-0.2523683857532224,
-0.25200854739575596,
-0.2516561583169735,
-0.25147053419035553,
-0.25187638352086095,
-0.25176343344798546,
-0.25256615785525305,
-0.25310796555079107,
-0.2535568871416053,
-0.2542411936874833,
-0.2544978632482573,
-0.2553210332506536,
-0.2567248511819892,
-0.2559665595456875,
-0.2564729970835735,
-0.2585267417223537,
-0.2573770145474615,
-0.2585495460828127,
-0.2593605768768532,
-0.25906572100606984,
-0.26026752519153573,
-0.2609952847918467,
-0.26222905157170767,
-0.26395874733435604,
-0.26404203898769246,
-0.26501581381370537,
-0.2666259054856709,
-0.2676190865432322,
-0.26813030555166134,
-0.26873271506658997,
-0.2624062353014993,
-0.2289515918968408,
-0.22755587298227964,
-0.24719513536827162
+0.25952333211898804,
+0.2590482831001282,
+0.24866817891597748,
+0.24776232242584229,
+0.22200720012187958,
+0.21363843977451324,
+0.20652402937412262,
+0.19909949600696564,
+0.2021811604499817,
+0.20355898141860962,
+0.20546883344650269,
+0.2061648815870285,
+0.20569036900997162,
+0.20412985980510712,
+0.20357738435268402,
+0.2041499763727188,
+0.2055872678756714,
+0.20807604491710663,
+0.21054454147815704,
+0.21341396868228912,
+0.21418628096580505,
+0.22065168619155884,
+0.2248840034008026,
+0.22723940014839172,
+0.230172261595726,
+0.23371541500091553,
+0.23734734952449799,
+0.23960146307945251,
+0.24088498950004578,
+0.241532102227211,
+0.24218633770942688,
+0.24371792376041412,
+0.2447739839553833,
+0.25564682483673096,
+0.2632736265659332,
+0.2549223005771637,
+0.24608071148395538,
+0.2464841604232788,
+0.2470586597919464,
+0.24785254895687103,
+0.24904784560203552,
+0.2503036856651306,
+0.25226327776908875,
+0.2532329559326172,
+0.2527913451194763,
+0.2518651783466339,
+0.2504975199699402,
+0.24836081266403198,
+0.24765831232070923,
+0.24767662584781647,
+0.24965286254882812,
+0.2501370906829834,
+0.2508895993232727,
+0.2512582540512085,
+0.25150999426841736,
+0.2525503635406494,
+0.25313329696655273,
+0.2534785270690918,
+0.25330957770347595,
+0.25366073846817017,
+0.25502219796180725,
+0.2608155608177185,
+0.25662899017333984,
+0.2558451294898987,
+0.25671014189720154,
+0.2577403485774994,
+0.25914356112480164,
+0.2596718966960907,
+0.25953933596611023,
+0.2610883116722107,
+0.26132410764694214,
+0.26272818446159363,
+0.26397505402565,
+0.26440608501434326,
+0.26543495059013367,
+0.26753780245780945,
+0.26935192942619324,
+0.26732245087623596,
+0.26666897535324097,
+0.2663257420063019
]
}
"""Generate feature statistics for LibriSpeech training set.
Example:
python global_stats.py --librispeech_path /home/librispeech
"""
import json
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torch
import torchaudio
from utils import GAIN, piecewise_linear_log, spectrogram_transform
logger = logging.getLogger()
def parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--librispeech_path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech datasets. "
"All of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.",
)
parser.add_argument(
"--output_path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="File to save feature statistics to. (Default: './global_stats.json')",
)
return parser.parse_args()
def generate_statistics(samples):
E_x = 0
E_x_2 = 0
N = 0
for idx, sample in enumerate(samples):
mel_spec = spectrogram_transform(sample[0].squeeze()).transpose(1, 0)
scaled_mel_spec = piecewise_linear_log(mel_spec * GAIN)
sum = scaled_mel_spec.sum(0)
sq_sum = scaled_mel_spec.pow(2).sum(0)
M = scaled_mel_spec.size(0)
E_x = E_x * (N / (N + M)) + sum / (N + M)
E_x_2 = E_x_2 * (N / (N + M)) + sq_sum / (N + M)
N += M
if idx % 100 == 0:
logger.info(f"Processed {idx}")
return E_x, (E_x_2 - E_x ** 2) ** 0.5
def cli_main():
args = parse_args()
dataset = torch.utils.data.ConcatDataset(
[
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-other-500"),
]
)
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)
mean, stddev = generate_statistics(iter(dataloader))
json_str = json.dumps({"mean": mean.tolist(), "invstddev": (1 / stddev).tolist()}, indent=2)
with open(args.output_path, "w") as f:
f.write(json_str)
if __name__ == "__main__":
cli_main()
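The incremental updates in `generate_statistics` keep `E_x` and `E_x_2` equal to the mean and mean-of-squares of every frame seen so far, so the whole training set never has to be held in memory. A minimal check of that recurrence against batch statistics, using random tensors as stand-ins for mel features (all names here are illustrative):
```
import torch

torch.manual_seed(0)
chunks = [torch.randn(n, 80) for n in (50, 70, 30)]  # fake (time, n_mels) feature chunks

E_x, E_x_2, N = 0, 0, 0
for chunk in chunks:
    M = chunk.size(0)
    # Same update rule as generate_statistics above.
    E_x = E_x * (N / (N + M)) + chunk.sum(0) / (N + M)
    E_x_2 = E_x_2 * (N / (N + M)) + chunk.pow(2).sum(0) / (N + M)
    N += M

all_frames = torch.cat(chunks)
assert torch.allclose(E_x, all_frames.mean(0), atol=1e-5)
assert torch.allclose((E_x_2 - E_x ** 2) ** 0.5, all_frames.std(0, unbiased=False), atol=1e-4)
```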
@@ -10,17 +10,12 @@ import torchaudio
import torchaudio.functional as F
from pytorch_lightning import LightningModule
from torchaudio.prototype.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base
+from utils import GAIN, piecewise_linear_log, spectrogram_transform

Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])

-_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
-_gain = pow(10, 0.05 * _decibel)
-_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)


def _batch_by_token_count(idx_target_lengths, token_limit):
    batches = []
    current_batch = []
@@ -119,12 +114,6 @@ class GlobalStatsNormalization(torch.nn.Module):
        return (input - self.mean) * self.invstddev


-def _piecewise_linear_log(x):
-    x[x > math.e] = torch.log(x[x > math.e])
-    x[x <= math.e] = x[x <= math.e] / math.e
-    return x


class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
        self.warmup_updates = warmup_updates
@@ -172,7 +161,7 @@ class RNNTModule(LightningModule):
        self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)

        self.train_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
+            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(lambda x: x.transpose(1, 2)),
            torchaudio.transforms.FrequencyMasking(27),
@@ -183,7 +172,7 @@
            FunctionalModule(lambda x: x.transpose(1, 2)),
        )
        self.valid_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
+            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(lambda x: x.transpose(1, 2)),
            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
@@ -206,14 +195,14 @@
        return targets, lengths

    def _train_extract_features(self, samples: List):
-        mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
+        mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.train_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _valid_extract_features(self, samples: List):
-        mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
+        mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.valid_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
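For context, `GlobalStatsNormalization` is what consumes the JSON file above at training time. A minimal sketch of such a module, assuming a constructor consistent with the `forward` shown in the hunk (the constructor body here is illustrative, not the commit's code):
```
import json

import torch


class GlobalStatsNormalization(torch.nn.Module):
    def __init__(self, global_stats_path):
        super().__init__()
        with open(global_stats_path) as f:
            stats = json.load(f)
        # One mean and one inverse standard deviation per mel bin,
        # as stored in global_stats.json.
        self.mean = torch.tensor(stats["mean"])
        self.invstddev = torch.tensor(stats["invstddev"])

    def forward(self, input):
        return (input - self.mean) * self.invstddev
```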
import math

import torch
import torchaudio

# DECIBEL is twice the peak int16 sample value expressed in dB;
# GAIN is the corresponding linear gain applied to features before the piecewise log.
DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
GAIN = pow(10, 0.05 * DECIBEL)

spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)


def piecewise_linear_log(x):
    # Linear below e, logarithmic above; both branches evaluate to 1 at x = e,
    # so the mapping is continuous. Note that x is modified in place.
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x
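A small usage example of the transform (output values are approximate):
```
import math

import torch

from utils import piecewise_linear_log  # assumes the utils.py module above

x = torch.tensor([0.5, math.e, 10.0])
y = piecewise_linear_log(x.clone())  # clone() because the function mutates its input
print(y)  # ~ tensor([0.1839, 1.0000, 2.3026])
```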