Commit 69467ea5 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Refactor LibriSpeech Conformer RNN-T recipe (#2366)

Summary:
Modifies the example LibriSpeech Conformer RNN-T recipe as follows:
- Moves data loading and transforms logic from lightning module to data module (improves generalizability and reusability of lightning module and data module).
- Moves transforms logic from dataloader collator function to dataset (resolves dataloader multiprocessing issues on certain platforms).
- Replaces lambda functions with `partial` equivalents (resolves pickling issues in certain runtime environments).
- Modifies the training script to allow specifying the path to a model checkpoint from which to restart training.

Pull Request resolved: https://github.com/pytorch/audio/pull/2366

Reviewed By: mthrok

Differential Revision: D36305028

Pulled By: hwangjeff

fbshipit-source-id: 0b768da5d5909136c55418bf0a3c2ddd0c5683ba
parent 93c26d63
import os
import random
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
def _batch_by_token_count(idx_target_lengths, max_tokens, batch_size=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > max_tokens or (batch_size and len(current_batch) == batch_size):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def get_sample_lengths(librispeech_dataset):
    """Return the transcript length (in characters) of every dataset sample.

    Lengths are returned in ``_walker`` order. Each chapter's transcript file
    is read at most once: the first sample of a chapter populates the cache
    for all file ids listed in that chapter's ``*.trans.txt``.
    """
    cache = {}

    def _load_chapter(speaker_id, chapter_id):
        # A chapter directory holds one transcript file whose lines are
        # "<fileid> <transcript>"; record the transcript length for each.
        trans_name = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
        trans_path = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, trans_name)
        with open(trans_path) as f:
            for row in f:
                entry_id, transcript = row.strip().split(" ", 1)
                cache[entry_id] = len(transcript)

    lengths = []
    for fileid in librispeech_dataset._walker:
        if fileid not in cache:
            speaker_id, chapter_id, _ = fileid.split("-")
            _load_chapter(speaker_id, chapter_id)
        lengths.append(cache[fileid])
    return lengths
class CustomBucketDataset(torch.utils.data.Dataset):
    """Map-style dataset whose items are length-bucketed batches of samples.

    Samples are assigned to ``num_buckets`` buckets by length, ordered within
    each bucket (shuffled, or longest-first), then packed into batches under
    the ``max_tokens`` budget and the optional ``batch_size`` cap.
    """

    def __init__(
        self,
        dataset,
        lengths,
        max_tokens,
        num_buckets,
        shuffle=False,
        batch_size=None,
    ):
        super().__init__()
        assert len(dataset) == len(lengths)
        self.dataset = dataset

        shortest, longest = min(lengths), max(lengths)
        # Every individual sample must fit within the token budget.
        assert max_tokens >= longest

        boundaries = torch.linspace(shortest, longest, num_buckets)
        length_tensor = torch.tensor(lengths)
        assignments = torch.bucketize(length_tensor, boundaries)
        entries = [(i, n, assignments[i]) for i, n in enumerate(length_tensor)]

        if shuffle:
            entries = random.sample(entries, len(entries))
        else:
            entries.sort(key=lambda e: e[1], reverse=True)
        # Stable sort preserves the per-bucket ordering chosen above.
        entries.sort(key=lambda e: e[2])

        self.batches = _batch_by_token_count(
            [(i, n) for i, n, _ in entries],
            max_tokens,
            batch_size=batch_size,
        )

    def __getitem__(self, idx):
        return [self.dataset[member] for member in self.batches[idx]]

    def __len__(self):
        return len(self.batches)
class TransformDataset(torch.utils.data.Dataset):
    """Applies ``transform_fn`` lazily to each element of ``dataset`` on access."""

    def __init__(self, dataset, transform_fn):
        self.dataset = dataset
        self.transform_fn = transform_fn

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return self.transform_fn(sample)
class LibriSpeechDataModule(LightningDataModule):
    """Data module producing token-budgeted, length-bucketed LibriSpeech batches.

    Batching and transform logic lives here (inside the datasets, not a
    dataloader ``collate_fn``) so the paired LightningModule stays free of
    data concerns and dataloader workers avoid pickling transform state.

    Args:
        librispeech_path: root directory containing the LibriSpeech datasets.
        train_transform: callable applied to each list of raw training samples.
        val_transform: callable applied to each list of raw validation samples.
        test_transform: callable applied to each raw test sample.
        max_tokens: per-batch budget of target tokens. (Default: 700)
        batch_size: hard cap on samples per batch. (Default: 2)
        train_num_buckets: number of length buckets for training. (Default: 50)
        train_shuffle: whether to shuffle training batches. (Default: True)
        num_workers: dataloader worker process count. (Default: 10)
    """

    def __init__(
        self,
        *,
        librispeech_path,
        train_transform,
        val_transform,
        test_transform,
        max_tokens=700,
        batch_size=2,
        train_num_buckets=50,
        train_shuffle=True,
        num_workers=10,
    ):
        # Fix: the base-class initializer was never invoked, leaving
        # LightningDataModule's own setup (hooks/flags) skipped, which
        # Lightning requires of subclasses.
        super().__init__()
        self.librispeech_path = librispeech_path
        # Transcript-length caches, filled lazily on first dataloader build so
        # the disk scan of transcript files happens at most once per split.
        self.train_dataset_lengths = None
        self.val_dataset_lengths = None
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform
        self.max_tokens = max_tokens
        self.batch_size = batch_size
        self.train_num_buckets = train_num_buckets
        self.train_shuffle = train_shuffle
        self.num_workers = num_workers

    def train_dataloader(self):
        """Bucketed dataloader over the three LibriSpeech training splits."""
        datasets = [
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
        ]
        if not self.train_dataset_lengths:
            self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
        dataset = torch.utils.data.ConcatDataset(
            [
                CustomBucketDataset(
                    dataset,
                    lengths,
                    self.max_tokens,
                    self.train_num_buckets,
                    batch_size=self.batch_size,
                )
                for dataset, lengths in zip(datasets, self.train_dataset_lengths)
            ]
        )
        dataset = TransformDataset(dataset, self.train_transform)
        # Items are already full batches, hence batch_size=None.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=self.num_workers,
            batch_size=None,
            shuffle=self.train_shuffle,
        )
        return dataloader

    def val_dataloader(self):
        """Dataloader over dev-clean + dev-other using a single length bucket."""
        datasets = [
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
        ]
        if not self.val_dataset_lengths:
            self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
        dataset = torch.utils.data.ConcatDataset(
            [
                CustomBucketDataset(
                    dataset,
                    lengths,
                    self.max_tokens,
                    1,
                    batch_size=self.batch_size,
                )
                for dataset, lengths in zip(datasets, self.val_dataset_lengths)
            ]
        )
        dataset = TransformDataset(dataset, self.val_transform)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
        return dataloader

    def test_dataloader(self):
        """Per-sample dataloader over test-clean."""
        dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
        dataset = TransformDataset(dataset, self.test_transform)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
        return dataloader
......@@ -4,7 +4,7 @@ from argparse import ArgumentParser
import torch
import torchaudio
from lightning import ConformerRNNTModule
from lightning import ConformerRNNTModule, get_data_module
logger = logging.getLogger()
......@@ -15,19 +15,15 @@ def compute_word_level_distance(seq1, seq2):
def run_eval(args):
model = ConformerRNNTModule.load_from_checkpoint(
args.checkpoint_path,
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
).eval()
model = ConformerRNNTModule.load_from_checkpoint(args.checkpoint_path, sp_model_path=str(args.sp_model_path)).eval()
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
if args.use_cuda:
model = model.to(device="cuda")
total_edit_distance = 0
total_length = 0
dataloader = model.test_dataloader()
dataloader = data_module.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2]
......@@ -45,6 +41,7 @@ def cli_main():
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
required=True,
)
parser.add_argument(
"--global-stats-path",
......@@ -56,11 +53,13 @@ def cli_main():
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--use-cuda",
......
import json
import logging
import math
import os
import random
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule, seed_everything
from data_module import LibriSpeechDataModule
from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch
from torchaudio.prototype.models import conformer_rnnt_base
from transforms import Batch, TrainTransform, ValTransform, TestTransform
logger = logging.getLogger()
seed_everything(1)
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
_expected_spm_vocab_size = 1023
def _piecewise_linear_log(x):
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
def _batch_by_token_count(idx_target_lengths, token_limit, sample_limit=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > token_limit or (sample_limit and len(current_batch) == sample_limit):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def get_sample_lengths(librispeech_dataset):
    """Return per-sample transcript lengths (in characters), in walker order."""
    length_by_fileid = {}

    def _length_of(fileid):
        if fileid in length_by_fileid:
            return length_by_fileid[fileid]
        # Cache miss: read this chapter's transcript file, which lists
        # "<fileid> <transcript>" for every utterance in the chapter.
        speaker_id, chapter_id, _ = fileid.split("-")
        trans_file = os.path.join(
            librispeech_dataset._path,
            speaker_id,
            chapter_id,
            speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt,
        )
        with open(trans_file) as f:
            for row in f:
                entry_id, transcript = row.strip().split(" ", 1)
                length_by_fileid[entry_id] = len(transcript)
        return length_by_fileid[fileid]

    return [_length_of(fid) for fid in librispeech_dataset._walker]
class CustomBucketDataset(torch.utils.data.Dataset):
    """Dataset of batches: samples grouped by similar target length.

    Each index maps to a complete batch (a list of raw samples) assembled
    under the ``max_token_limit`` budget and an optional per-batch sample cap.
    """

    def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False, sample_limit=None):
        super().__init__()
        assert len(dataset) == len(lengths)
        self.dataset = dataset

        lo, hi = min(lengths), max(lengths)
        # A single sample must always fit within the token budget.
        assert max_token_limit >= hi

        edges = torch.linspace(lo, hi, num_buckets)
        lengths_t = torch.tensor(lengths)
        bucket_of = torch.bucketize(lengths_t, edges)
        triples = [(i, n, bucket_of[i]) for i, n in enumerate(lengths_t)]

        if shuffle:
            triples = random.sample(triples, len(triples))
        else:
            triples = sorted(triples, key=lambda t: t[1], reverse=True)
        # Stable secondary sort groups the entries bucket by bucket while
        # keeping the ordering chosen above within each bucket.
        triples = sorted(triples, key=lambda t: t[2])

        self.batches = _batch_by_token_count(
            [(i, n) for i, n, _ in triples], max_token_limit, sample_limit=sample_limit
        )

    def __getitem__(self, idx):
        return [self.dataset[member] for member in self.batches[idx]]

    def __len__(self):
        return len(self.batches)
class FunctionalModule(torch.nn.Module):
    """Adapts a plain callable so it can sit inside ``torch.nn.Sequential``."""

    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        output = self.functional(input)
        return output
class GlobalStatsNormalization(torch.nn.Module):
    """Normalizes features using precomputed dataset-wide statistics.

    Loads "mean" and "invstddev" from the JSON file at ``global_stats_path``
    and applies ``(input - mean) * invstddev`` element-wise.
    """

    def __init__(self, global_stats_path):
        super().__init__()
        with open(global_stats_path) as f:
            stats = json.load(f)
        self.mean = torch.tensor(stats["mean"])
        self.invstddev = torch.tensor(stats["invstddev"])

    def forward(self, input):
        return (input - self.mean) * self.invstddev
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
r"""Learning rate scheduler that performs linear warmup and exponential annealing.
......@@ -189,13 +74,7 @@ def post_process_hypos(
class ConformerRNNTModule(LightningModule):
def __init__(
self,
*,
librispeech_path: str,
sp_model_path: str,
global_stats_path: str,
):
def __init__(self, sp_model_path):
super().__init__()
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
......@@ -214,65 +93,8 @@ class ConformerRNNTModule(LightningModule):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(lambda x: x.transpose(1, 2)),
)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
GlobalStatsNormalization(global_stats_path),
)
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.automatic_optimization = False
def _extract_labels(self, samples: List):
    """Encode each sample's transcript (``sample[2]``) into padded int32 targets.

    Returns ``(targets, lengths)`` where ``targets`` is a right-padded
    (padding value 1) batch-first tensor and ``lengths`` holds the
    unpadded token counts.
    """
    token_ids = [self.sp_model.encode(sample[2].lower()) for sample in samples]
    lengths = torch.tensor([len(ids) for ids in token_ids]).to(dtype=torch.int32)
    padded = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(ids) for ids in token_ids],
        batch_first=True,
        padding_value=1.0,
    ).to(dtype=torch.int32)
    return padded, lengths
def _train_extract_features(self, samples: List):
    """Compute padded mel-spectrogram features through the training pipeline.

    Returns ``(features, lengths)`` with per-sample frame counts taken
    before padding.
    """
    mels = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
    lengths = torch.tensor([mel.shape[0] for mel in mels], dtype=torch.int32)
    padded = torch.nn.utils.rnn.pad_sequence(mels, batch_first=True)
    features = self.train_data_pipeline(padded)
    return features, lengths
def _valid_extract_features(self, samples: List):
    """Compute padded mel-spectrogram features through the validation pipeline.

    Same as the training variant, but without the augmenting transforms.
    """
    mels = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
    lengths = torch.tensor([mel.shape[0] for mel in mels], dtype=torch.int32)
    padded = torch.nn.utils.rnn.pad_sequence(mels, batch_first=True)
    features = self.valid_data_pipeline(padded)
    return features, lengths
def _train_collate_fn(self, samples: List):
    """Collate raw samples into a training ``Batch``."""
    targets, target_lengths = self._extract_labels(samples)
    features, feature_lengths = self._train_extract_features(samples)
    return Batch(features, feature_lengths, targets, target_lengths)
def _valid_collate_fn(self, samples: List):
    """Collate raw samples into a validation ``Batch``."""
    targets, target_lengths = self._extract_labels(samples)
    features, feature_lengths = self._valid_extract_features(samples)
    return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List):
    """Return the collated batch together with the raw samples (evaluation
    needs the original transcripts)."""
    batch = self._valid_collate_fn(samples)
    return batch, samples
def _step(self, batch, _, step_type):
if batch is None:
return None
......@@ -348,55 +170,14 @@ class ConformerRNNTModule(LightningModule):
def test_step(self, batch, batch_idx):
    # Delegate to the shared _step implementation, tagging this as a "test" step.
    return self._step(batch, batch_idx, "test")
def train_dataloader(self):
    """Dataloader over all three LibriSpeech training splits with bucketed batches."""
    urls = ["train-clean-360", "train-clean-100", "train-other-500"]
    datasets = [torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url=u) for u in urls]
    if not self.train_dataset_lengths:
        # Cache transcript lengths so the disk scan happens only once.
        self.train_dataset_lengths = [get_sample_lengths(d) for d in datasets]
    bucketed = [
        CustomBucketDataset(d, lens, 700, 50, shuffle=False, sample_limit=2)
        for d, lens in zip(datasets, self.train_dataset_lengths)
    ]
    # Dataset items are already batches, hence batch_size=None.
    return torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset(bucketed),
        collate_fn=self._train_collate_fn,
        num_workers=10,
        batch_size=None,
        shuffle=True,
    )
def val_dataloader(self):
    """Dataloader over dev-clean + dev-other using a single length bucket."""
    urls = ["dev-clean", "dev-other"]
    datasets = [torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url=u) for u in urls]
    if not self.val_dataset_lengths:
        # Cache transcript lengths so the disk scan happens only once.
        self.val_dataset_lengths = [get_sample_lengths(d) for d in datasets]
    bucketed = [
        CustomBucketDataset(d, lens, 700, 1, sample_limit=2)
        for d, lens in zip(datasets, self.val_dataset_lengths)
    ]
    # Dataset items are already batches, hence batch_size=None.
    return torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset(bucketed),
        batch_size=None,
        collate_fn=self._valid_collate_fn,
        num_workers=10,
    )
def test_dataloader(self):
    """Per-sample dataloader over test-clean; the collator also returns raw samples."""
    dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
    return torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
def get_data_module(librispeech_path, global_stats_path, sp_model_path):
    """Build a LibriSpeechDataModule wired with train/val/test transforms."""
    transform_kwargs = dict(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
    return LibriSpeechDataModule(
        librispeech_path=librispeech_path,
        train_transform=TrainTransform(**transform_kwargs),
        val_transform=ValTransform(**transform_kwargs),
        test_transform=TestTransform(**transform_kwargs),
    )
import pathlib
from argparse import ArgumentParser
from lightning import ConformerRNNTModule
from pytorch_lightning import Trainer
from lightning import ConformerRNNTModule, get_data_module
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.plugins import DDPPlugin
def run_train(args):
seed_everything(1)
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
......@@ -42,16 +43,19 @@ def run_train(args):
reload_dataloaders_every_n_epochs=1,
)
model = ConformerRNNTModule(
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
trainer.fit(model)
model = ConformerRNNTModule(str(args.sp_model_path))
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
trainer.fit(model, data_module, ckpt_path=args.checkpoint_path)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
default=None,
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
......@@ -68,11 +72,13 @@ def cli_main():
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--nodes",
......@@ -93,7 +99,6 @@ def cli_main():
help="Number of epochs to train for. (Default: 120)",
)
args = parser.parse_args()
run_train(args)
......
import json
import math
from collections import namedtuple
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
# Bundles padded features/targets with their true (unpadded) lengths.
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
# Full-scale int16 level expressed in dB (2 * 20 * log10(32767)), then
# converted to the corresponding linear gain factor.
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
# 80-band mel spectrogram at 16 kHz: 400-sample (25 ms) FFT window, 160-sample (10 ms) hop.
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
def _piecewise_linear_log(x):
    """Scale input by ``_gain`` and apply a piecewise-linear log compression.

    Applies log(x) where x > e and x / e elsewhere. The branch mask is
    computed once, before mutation: the previous code re-evaluated
    ``x <= math.e`` on the already-transformed tensor, so values whose log
    dropped back below e were incorrectly divided by e as well.

    Args:
        x: input tensor; not mutated (``x * _gain`` allocates a new tensor).

    Returns:
        The compressed tensor.
    """
    x = x * _gain
    mask = x > math.e
    x[mask] = torch.log(x[mask])
    x[~mask] = x[~mask] / math.e
    return x
class FunctionalModule(torch.nn.Module):
    """nn.Module wrapper around an arbitrary callable.

    Lets plain functions participate in ``torch.nn.Sequential`` pipelines.
    """

    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        fn = self.functional
        return fn(input)
class GlobalStatsNormalization(torch.nn.Module):
    """Applies dataset-wide mean/inverse-stddev normalization to features.

    Statistics come from a JSON file containing "mean" and "invstddev" arrays.
    """

    def __init__(self, global_stats_path):
        super().__init__()
        with open(global_stats_path) as f:
            blob = json.load(f)
        self.mean = torch.tensor(blob["mean"])
        self.invstddev = torch.tensor(blob["invstddev"])

    def forward(self, input):
        centered = input - self.mean
        return centered * self.invstddev
def _extract_labels(sp_model, samples: List):
targets = [sp_model.encode(sample[2].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _extract_features(data_pipeline, samples: List):
    """Compute padded mel-spectrogram features for ``samples`` via ``data_pipeline``.

    Returns ``(features, lengths)``; lengths are per-sample frame counts
    taken before padding.
    """
    mels = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
    lengths = torch.tensor([mel.shape[0] for mel in mels], dtype=torch.int32)
    padded = torch.nn.utils.rnn.pad_sequence(mels, batch_first=True)
    features = data_pipeline(padded)
    return features, lengths
class TrainTransform:
    """Maps a list of raw LibriSpeech samples to a training ``Batch``.

    The feature pipeline applies piecewise-log compression, global
    mean/variance normalization, and SpecAugment-style frequency/time masking.
    """

    def __init__(self, global_stats_path: str, sp_model_path: str):
        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        # The masking transforms expect the last two axes swapped relative to
        # the padded features; the surrounding transposes handle that.
        self.train_data_pipeline = torch.nn.Sequential(
            FunctionalModule(_piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        )

    def __call__(self, samples: List):
        targets, target_lengths = _extract_labels(self.sp_model, samples)
        features, feature_lengths = _extract_features(self.train_data_pipeline, samples)
        return Batch(features, feature_lengths, targets, target_lengths)
class ValTransform:
    """Maps a list of raw samples to a validation ``Batch`` (no augmentation)."""

    def __init__(self, global_stats_path: str, sp_model_path: str):
        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        self.valid_data_pipeline = torch.nn.Sequential(
            FunctionalModule(_piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
        )

    def __call__(self, samples: List):
        targets, target_lengths = _extract_labels(self.sp_model, samples)
        features, feature_lengths = _extract_features(self.valid_data_pipeline, samples)
        return Batch(features, feature_lengths, targets, target_lengths)
class TestTransform:
    """Evaluation transform: wraps ValTransform and also returns the raw sample.

    ``__call__`` takes a single sample and yields ``(batch, [sample])`` so the
    evaluation loop can recover the reference transcript.
    """

    def __init__(self, global_stats_path: str, sp_model_path: str):
        self.val_transforms = ValTransform(global_stats_path, sp_model_path)

    def __call__(self, sample):
        wrapped = [sample]
        batch = self.val_transforms(wrapped)
        return batch, wrapped
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment