Commit c262758b authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add Conformer RNN-T LibriSpeech training recipe (#2329)

Summary:
Adds Conformer RNN-T LibriSpeech training recipe to examples directory.

Produces 30M-parameter model that achieves the following WER:

|                     |          WER |
|:-------------------:|-------------:|
| test-clean          |       0.0310 |
| test-other          |       0.0805 |
| dev-clean           |       0.0314 |
| dev-other           |       0.0827 |

Pull Request resolved: https://github.com/pytorch/audio/pull/2329

Reviewed By: xiaohui-zhang

Differential Revision: D35578727

Pulled By: hwangjeff

fbshipit-source-id: afa9146c5b647727b8605d104d928110a1d3976d
parent fb51cecc
# Conformer RNN-T ASR Example
This directory contains sample implementations of training and evaluation pipelines for a Conformer RNN-T ASR model.
## Setup
### Install PyTorch and TorchAudio nightly or from source
Because Conformer RNN-T is currently a prototype feature, you will need to either use the TorchAudio nightly build or build TorchAudio from source. Note also that GPU support is required for training.
To install the nightly, follow the directions at <https://pytorch.org/>.
To build TorchAudio from source, refer to the [contributing guidelines](https://github.com/pytorch/audio/blob/main/CONTRIBUTING.md).
### Install additional dependencies
```bash
pip install pytorch-lightning sentencepiece
```
## Usage
### Training
[`train.py`](./train.py) trains an Conformer RNN-T model (30.2M parameters, 121MB) on LibriSpeech using PyTorch Lightning. Note that the script expects users to have the following:
- Access to GPU nodes for training.
- Full LibriSpeech dataset.
- SentencePiece model to be used to encode targets; the model can be generated using [`train_spm.py`](./train_spm.py).
- File (--global_stats_path) that contains training set feature statistics; this file can be generated using [`global_stats.py`](../emformer_rnnt/global_stats.py).
Sample SLURM command:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp_dir ./experiments --librispeech_path ./librispeech/ --global_stats_path ./global_stats.json --sp_model_path ./spm_unigram_1023.model --epochs 160
```
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Conformer RNN-T model on LibriSpeech test-clean.
Sample SLURM command:
```
srun python eval.py --checkpoint_path ./experiments/checkpoints/epoch=159.ckpt --librispeech_path ./librispeech/ --sp_model_path ./spm_unigram_1023.model --use_cuda
```
The table below contains WER results for various splits.
| | WER |
|:-------------------:|-------------:|
| test-clean | 0.0310 |
| test-other | 0.0805 |
| dev-clean | 0.0314 |
| dev-other | 0.0827 |
import logging
import pathlib
from argparse import ArgumentParser
import torch
import torchaudio
from lightning import ConformerRNNTModule
logger = logging.getLogger()
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def run_eval(args):
model = ConformerRNNTModule.load_from_checkpoint(
args.checkpoint_path,
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
).eval()
if args.use_cuda:
model = model.to(device="cuda")
total_edit_distance = 0
total_length = 0
dataloader = model.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2]
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.warning(f"Final WER: {total_edit_distance / total_length}")
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
args = parser.parse_args()
run_eval(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
15.058613777160645,
16.34557342529297,
16.34653663635254,
16.240671157836914,
17.45355224609375,
17.445302963256836,
17.52323341369629,
18.076807022094727,
17.699262619018555,
17.706790924072266,
17.24724578857422,
17.153791427612305,
17.213361740112305,
17.347240447998047,
17.331117630004883,
17.21516227722168,
17.030071258544922,
16.818960189819336,
16.573062896728516,
16.29717254638672,
16.00996971130371,
15.794167518615723,
15.616395950317383,
15.459056854248047,
15.306838989257812,
15.199165344238281,
15.208144187927246,
14.883454322814941,
14.787869453430176,
14.947835922241211,
14.5912504196167,
14.76955509185791,
14.617781639099121,
14.840407371520996,
14.83073616027832,
14.909119606018066,
14.89070987701416,
14.918207168579102,
14.939517974853516,
14.913643836975098,
14.863334655761719,
14.803299903869629,
14.751264572143555,
14.688116073608398,
14.63498306274414,
14.615056037902832,
14.680213928222656,
14.616259574890137,
14.707776069641113,
14.630264282226562,
14.644737243652344,
14.547430038452148,
14.529033660888672,
14.49357795715332,
14.411538124084473,
14.33312702178955,
14.260393142700195,
14.204919815063477,
14.130182266235352,
14.06987476348877,
14.010197639465332,
13.938552856445312,
13.750232696533203,
13.607213973999023,
13.457777976989746,
13.31512451171875,
13.167718887329102,
13.019341468811035,
12.8869047164917,
12.795098304748535,
12.685126304626465,
12.620392799377441,
12.58949089050293,
12.537697792053223,
12.496938705444336,
12.410022735595703,
12.346826553344727,
12.221966743469238,
12.122841835021973,
12.005624771118164
],
"invstddev": [
0.25952333211898804,
0.2590482831001282,
0.24866817891597748,
0.24776232242584229,
0.22200720012187958,
0.21363843977451324,
0.20652402937412262,
0.19909949600696564,
0.2021811604499817,
0.20355898141860962,
0.20546883344650269,
0.2061648815870285,
0.20569036900997162,
0.20412985980510712,
0.20357738435268402,
0.2041499763727188,
0.2055872678756714,
0.20807604491710663,
0.21054454147815704,
0.21341396868228912,
0.21418628096580505,
0.22065168619155884,
0.2248840034008026,
0.22723940014839172,
0.230172261595726,
0.23371541500091553,
0.23734734952449799,
0.23960146307945251,
0.24088498950004578,
0.241532102227211,
0.24218633770942688,
0.24371792376041412,
0.2447739839553833,
0.25564682483673096,
0.2632736265659332,
0.2549223005771637,
0.24608071148395538,
0.2464841604232788,
0.2470586597919464,
0.24785254895687103,
0.24904784560203552,
0.2503036856651306,
0.25226327776908875,
0.2532329559326172,
0.2527913451194763,
0.2518651783466339,
0.2504975199699402,
0.24836081266403198,
0.24765831232070923,
0.24767662584781647,
0.24965286254882812,
0.2501370906829834,
0.2508895993232727,
0.2512582540512085,
0.25150999426841736,
0.2525503635406494,
0.25313329696655273,
0.2534785270690918,
0.25330957770347595,
0.25366073846817017,
0.25502219796180725,
0.2608155608177185,
0.25662899017333984,
0.2558451294898987,
0.25671014189720154,
0.2577403485774994,
0.25914356112480164,
0.2596718966960907,
0.25953933596611023,
0.2610883116722107,
0.26132410764694214,
0.26272818446159363,
0.26397505402565,
0.26440608501434326,
0.26543495059013367,
0.26753780245780945,
0.26935192942619324,
0.26732245087623596,
0.26666897535324097,
0.2663257420063019
]
}
import json
import logging
import math
import os
import random
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule, seed_everything
from torchaudio.models import Hypothesis, RNNTBeamSearch
from torchaudio.prototype.models import conformer_rnnt_base
logger = logging.getLogger()
seed_everything(1)
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
_expected_spm_vocab_size = 1023
def _piecewise_linear_log(x):
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
def _batch_by_token_count(idx_target_lengths, token_limit, sample_limit=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > token_limit or (sample_limit and len(current_batch) == sample_limit):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def get_sample_lengths(librispeech_dataset):
fileid_to_target_length = {}
def _target_length(fileid):
if fileid not in fileid_to_target_length:
speaker_id, chapter_id, _ = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
fileid_to_target_length[fileid_text] = len(transcript)
return fileid_to_target_length[fileid]
return [_target_length(fileid) for fileid in librispeech_dataset._walker]
class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False, sample_limit=None):
super().__init__()
assert len(dataset) == len(lengths)
self.dataset = dataset
max_length = max(lengths)
min_length = min(lengths)
assert max_token_limit >= max_length
buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)
idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)
sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets], max_token_limit, sample_limit=sample_limit
)
def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
r"""Learning rate scheduler that performs linear warmup and exponential annealing.
Args:
optimizer (torch.optim.Optimizer): optimizer to use.
warmup_steps (int): number of scheduler steps for which to warm up learning rate.
force_anneal_step (int): scheduler step at which annealing of learning rate begins.
anneal_factor (float): factor to scale base learning rate by at each annealing step.
last_epoch (int, optional): The index of last epoch. (Default: -1)
verbose (bool, optional): If ``True``, prints a message to stdout for
each update. (Default: ``False``)
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
force_anneal_step: int,
anneal_factor: float,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_steps
self.force_anneal_step = force_anneal_step
self.anneal_factor = anneal_factor
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count < self.force_anneal_step:
return [(min(1.0, self._step_count / self.warmup_steps)) * base_lr for base_lr in self.base_lrs]
else:
scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_step)
return [scaling_factor * base_lr for base_lr in self.base_lrs]
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ali = [h.alignment[1:] for h in hypos]
hypos_ids = [h.tokens[1:] for h in hypos]
hypos_score = [[math.exp(h.score)] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
return nbest_batch
class ConformerRNNTModule(LightningModule):
def __init__(
self,
*,
librispeech_path: str,
sp_model_path: str,
global_stats_path: str,
):
super().__init__()
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
# ``conformer_rnnt_base`` hardcodes a specific Conformer RNN-T configuration.
# For greater customizability, please refer to ``conformer_rnnt_model``.
self.model = conformer_rnnt_base()
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(lambda x: x.transpose(1, 2)),
)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
GlobalStatsNormalization(global_stats_path),
)
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.automatic_optimization = False
def _extract_labels(self, samples: List):
targets = [self.sp_model.encode(sample[2].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _train_extract_features(self, samples: List):
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.train_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _valid_extract_features(self, samples: List):
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.valid_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _train_collate_fn(self, samples: List):
features, feature_lengths = self._train_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _valid_collate_fn(self, samples: List):
features, feature_lengths = self._valid_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List):
return self._valid_collate_fn(samples), samples
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _ = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": "epoch"}],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
"""Custom training step.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data commonly yields.
"""
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.features.size(0)
batch_sizes = self.all_gather(batch_size)
self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.0)
opt.step()
# step every epoch
sch = self.lr_schedulers()
if self.trainer.is_last_batch:
sch.step()
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
def train_dataloader(self):
datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
]
if not self.train_dataset_lengths:
self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(dataset, lengths, 700, 50, shuffle=False, sample_limit=2)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=self._train_collate_fn,
num_workers=10,
batch_size=None,
shuffle=True,
)
return dataloader
def val_dataloader(self):
datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
]
if not self.val_dataset_lengths:
self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(dataset, lengths, 700, 1, sample_limit=2)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._valid_collate_fn,
num_workers=10,
)
return dataloader
def test_dataloader(self):
dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader
import pathlib
from argparse import ArgumentParser
from lightning import ConformerRNNTModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.plugins import DDPPlugin
def run_train(args):
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [
checkpoint,
train_checkpoint,
lr_monitor,
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.nodes,
gpus=args.gpus,
accelerator="gpu",
strategy=DDPPlugin(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
model = ConformerRNNTModule(
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
trainer.fit(model)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=120,
type=int,
help="Number of epochs to train for. (Default: 120)",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.
Example:
python train_spm.py --librispeech-path ./datasets
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
def get_transcript_text(transcript_path):
with open(transcript_path) as f:
return [line.strip().split(" ", 1)[1].lower() for line in f]
def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*/*.trans.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts
def train_spm(input):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=1023,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def parse_args():
default_output_path = "./spm_unigram_1023.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--librispeech-path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
return parser.parse_args()
def run_cli():
args = parse_args()
root = args.librispeech_path / "LibriSpeech"
splits = ["train-clean-100", "train-clean-360", "train-other-500"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)
model = train_spm(merged_transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
if __name__ == "__main__":
run_cli()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment