Commit 33bcb7b0 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Refactor Emformer RNNT recipes (#2212)

Summary:
Consolidates LibriSpeech and TED-LIUM Release 3 Emformer RNN-T training recipes in a single directory.

Pull Request resolved: https://github.com/pytorch/audio/pull/2212

Reviewed By: mthrok

Differential Revision: D34120104

Pulled By: hwangjeff

fbshipit-source-id: 29c6e27195d5998f76d67c35b718110e73529456
parent 87d7694d
# Emformer RNN-T ASR Example
This directory contains sample implementations of training and evaluation pipelines for an Emformer RNN-T streaming ASR model.
## Usage
### Training
[`train.py`](./train.py) trains an Emformer RNN-T model using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to datasets and the SentencePiece model to be used to encode targets. The script also expects a file (--global_stats_path) that contains training set feature statistics; this file can be generated via [`global_stats.py`](./global_stats.py).
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on a given dataset.
## Model Types
Currently, we have training recipes for the LibriSpeech and TED-LIUM Release 3 datasets.
### LibriSpeech
Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --model_type librispeech --exp_dir ./experiments --dataset_path ./datasets/librispeech --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_4096.model
```
Sample SLURM command for evaluation:
```
srun python eval.py --model_type librispeech --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --dataset_path ./datasets/librispeech --sp_model_path ./spm_bpe_4096.model --use_cuda
```
Using the sample training command above along with a SentencePiece model trained on LibriSpeech with vocab size 4096 and type bpe, [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves an WER of 0.0456 when evaluated on test-clean with [`eval.py`](./eval.py).
The table below contains WER results for various splits.
| | WER |
|:-------------------:|-------------:|
| test-clean | 0.0456 |
| test-other | 0.1066 |
| dev-clean | 0.0415 |
| dev-other | 0.1110 |
[`librispeech/pipeline_demo.py`](./librispeech/pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH` bundle that wraps a pre-trained Emformer RNN-T produced by the above recipe to perform streaming and full-context ASR on several LibriSpeech samples.
### TED-LIUM Release 3
Whereas the LibriSpeech model is configured with a vocabulary size of 4096, the TED-LIUM Release 3 model is configured with a vocabulary size of 500. Consequently, the TED-LIUM Release 3 model's last linear layer in the joiner has an output dimension of 501 (500 + 1 to account for the blank symbol); the rest of the model is identical to the LibriSpeech model.
Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --model_type tedlium3 --exp_dir ./experiments --dataset_path ./datasets/tedlium --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_500.model --gradient_clip_val 5.0
```
Sample SLURM command for evaluation:
```
srun python eval.py --model_type tedlium3 --checkpoint_path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium_path ./datasets/tedlium --sp_model_path ./spm-bpe-500.model --use_cuda
```
The table below contains WER results for dev and test subsets of TED-LIUM release 3.
| | WER |
|:-----------:|-------------:|
| dev | 0.108 |
| test | 0.098 |
[`tedlium3/eval_pipeline.py`](./tedlium3/eval_pipeline.py) evaluates the pre-trained `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. Running the script should produce WER results that are identical to those in the above table.
import json
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from torchaudio.models import Hypothesis
MODEL_TYPE_LIBRISPEECH = "librispeech"
MODEL_TYPE_TEDLIUM3 = "tedlium3"
DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
GAIN = pow(10, 0.05 * DECIBEL)
spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
def piecewise_linear_log(x):
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
def batch_by_token_count(idx_target_lengths, token_limit):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > token_limit:
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ali = [h.alignment[1:] for h in hypos]
hypos_ids = [h.tokens[1:] for h in hypos]
hypos_score = [[math.exp(h.score)] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
return nbest_batch
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
self.warmup_updates = warmup_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
...@@ -4,7 +4,9 @@ from argparse import ArgumentParser ...@@ -4,7 +4,9 @@ from argparse import ArgumentParser
import torch import torch
import torchaudio import torchaudio
from lightning import RNNTModule from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from librispeech.lightning import LibriSpeechRNNTModule
from tedlium3.lightning import TEDLIUM3RNNTModule
logger = logging.getLogger() logger = logging.getLogger()
...@@ -14,23 +16,13 @@ def compute_word_level_distance(seq1, seq2): ...@@ -14,23 +16,13 @@ def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split()) return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def run_eval(args): def run_eval(model):
model = RNNTModule.load_from_checkpoint(
args.checkpoint_path,
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
).eval()
if args.use_cuda:
model = model.to(device="cuda")
total_edit_distance = 0 total_edit_distance = 0
total_length = 0 total_length = 0
dataloader = model.test_dataloader() dataloader = model.test_dataloader()
with torch.no_grad(): with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader): for idx, (batch, transcripts) in enumerate(dataloader):
actual = sample[0][2] actual = transcripts[0]
predicted = model(batch) predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted) total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split()) total_length += len(actual.split())
...@@ -39,8 +31,28 @@ def run_eval(args): ...@@ -39,8 +31,28 @@ def run_eval(args):
logger.info(f"Final WER: {total_edit_distance / total_length}") logger.info(f"Final WER: {total_edit_distance / total_length}")
def cli_main(): def get_lightning_module(args):
if args.model_type == MODEL_TYPE_LIBRISPEECH:
return LibriSpeechRNNTModule.load_from_checkpoint(
args.checkpoint_path,
librispeech_path=str(args.dataset_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
elif args.model_type == MODEL_TYPE_TEDLIUM3:
return TEDLIUM3RNNTModule.load_from_checkpoint(
args.checkpoint_path,
tedlium_path=str(args.dataset_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
else:
raise ValueError(f"Encountered unsupported model type {args.model_type}.")
def parse_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
parser.add_argument( parser.add_argument(
"--checkpoint_path", "--checkpoint_path",
type=pathlib.Path, type=pathlib.Path,
...@@ -53,9 +65,9 @@ def cli_main(): ...@@ -53,9 +65,9 @@ def cli_main():
help="Path to JSON file containing feature means and stddevs.", help="Path to JSON file containing feature means and stddevs.",
) )
parser.add_argument( parser.add_argument(
"--librispeech_path", "--dataset_path",
type=pathlib.Path, type=pathlib.Path,
help="Path to LibriSpeech datasets.", help="Path to dataset.",
) )
parser.add_argument( parser.add_argument(
"--sp_model_path", "--sp_model_path",
...@@ -68,8 +80,23 @@ def cli_main(): ...@@ -68,8 +80,23 @@ def cli_main():
default=False, default=False,
help="Run using CUDA.", help="Run using CUDA.",
) )
args = parser.parse_args() parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
run_eval(args) return parser.parse_args()
def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = parse_args()
init_logger(args.debug)
model = get_lightning_module(args)
if args.use_cuda:
model = model.to(device="cuda")
run_eval(model)
if __name__ == "__main__": if __name__ == "__main__":
......
"""Generate feature statistics for LibriSpeech training set. """Generate feature statistics for training set.
Example: Example:
python global_stats.py --librispeech_path /home/librispeech python global_stats.py --model_type librispeech --dataset_path /home/librispeech
""" """
import json import json
...@@ -11,19 +11,20 @@ from argparse import ArgumentParser, RawTextHelpFormatter ...@@ -11,19 +11,20 @@ from argparse import ArgumentParser, RawTextHelpFormatter
import torch import torch
import torchaudio import torchaudio
from utils import GAIN, piecewise_linear_log, spectrogram_transform from common import GAIN, MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform
logger = logging.getLogger() logger = logging.getLogger()
def parse_args(): def parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter) parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
parser.add_argument( parser.add_argument(
"--librispeech_path", "--dataset_path",
required=True, required=True,
type=pathlib.Path, type=pathlib.Path,
help="Path to LibriSpeech datasets. " help="Path to dataset. "
"All of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.", "For LibriSpeech, all of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.",
) )
parser.add_argument( parser.add_argument(
"--output_path", "--output_path",
...@@ -56,15 +57,24 @@ def generate_statistics(samples): ...@@ -56,15 +57,24 @@ def generate_statistics(samples):
return E_x, (E_x_2 - E_x ** 2) ** 0.5 return E_x, (E_x_2 - E_x ** 2) ** 0.5
def get_dataset(args):
if args.model_type == MODEL_TYPE_LIBRISPEECH:
return torch.utils.data.ConcatDataset(
[
torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(args.dataset_path, url="train-other-500"),
]
)
elif args.model_type == MODEL_TYPE_TEDLIUM3:
return torchaudio.datasets.TEDLIUM(args.dataset_path, release="release3", subset="train")
else:
raise ValueError(f"Encountered unsupported model type {args.model_type}.")
def cli_main(): def cli_main():
args = parse_args() args = parse_args()
dataset = torch.utils.data.ConcatDataset( dataset = get_dataset(args)
[
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="train-other-500"),
]
)
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4) dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)
mean, stddev = generate_statistics(iter(dataloader)) mean, stddev = generate_statistics(iter(dataloader))
......
import json
import math
import os import os
from collections import namedtuple from typing import List
from typing import List, Tuple
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
import torchaudio.functional as F from common import (
GAIN,
Batch,
FunctionalModule,
GlobalStatsNormalization,
WarmupLR,
batch_by_token_count,
piecewise_linear_log,
post_process_hypos,
spectrogram_transform,
)
from pytorch_lightning import LightningModule from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base from torchaudio.models import RNNTBeamSearch, emformer_rnnt_base
from utils import GAIN, piecewise_linear_log, spectrogram_transform
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
def _batch_by_token_count(idx_target_lengths, token_limit):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > token_limit:
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
class CustomDataset(torch.utils.data.Dataset): class CustomDataset(torch.utils.data.Dataset):
r"""Sort samples by target length and batch to max token count.""" r"""Sort LibriSpeech samples by target length and batch to max token count."""
def __init__(self, base_dataset, max_token_limit): def __init__(self, base_dataset, max_token_limit):
super().__init__() super().__init__()
...@@ -54,7 +38,7 @@ class CustomDataset(torch.utils.data.Dataset): ...@@ -54,7 +38,7 @@ class CustomDataset(torch.utils.data.Dataset):
assert max_token_limit >= idx_target_lengths[0][1] assert max_token_limit >= idx_target_lengths[0][1]
self.batches = _batch_by_token_count(idx_target_lengths, max_token_limit) self.batches = batch_by_token_count(idx_target_lengths, max_token_limit)
def _target_length(self, fileid, fileid_to_target_length): def _target_length(self, fileid, fileid_to_target_length):
if fileid not in fileid_to_target_length: if fileid not in fileid_to_target_length:
...@@ -77,74 +61,7 @@ class CustomDataset(torch.utils.data.Dataset): ...@@ -77,74 +61,7 @@ class CustomDataset(torch.utils.data.Dataset):
return len(self.batches) return len(self.batches)
class TimeMasking(torchaudio.transforms._AxisMasking): class LibriSpeechRNNTModule(LightningModule):
def __init__(self, time_mask_param: int, min_mask_p: float, iid_masks: bool = False) -> None:
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
self.min_mask_p = min_mask_p
def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:
if self.iid_masks and specgram.dim() == 4:
mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis + 1])
return F.mask_along_axis_iid(specgram, mask_param, mask_value, self.axis + 1)
else:
mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis])
return F.mask_along_axis(specgram, mask_param, mask_value, self.axis)
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
self.warmup_updates = warmup_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ali = [h.alignment[1:] for h in hypos]
hypos_ids = [h.tokens[1:] for h in hypos]
hypos_score = [[math.exp(h.score)] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
return nbest_batch
class RNNTModule(LightningModule):
def __init__( def __init__(
self, self,
*, *,
...@@ -157,7 +74,6 @@ class RNNTModule(LightningModule): ...@@ -157,7 +74,6 @@ class RNNTModule(LightningModule):
self.model = emformer_rnnt_base(num_symbols=4097) self.model = emformer_rnnt_base(num_symbols=4097)
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0) self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.96, patience=0)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000) self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
self.train_data_pipeline = torch.nn.Sequential( self.train_data_pipeline = torch.nn.Sequential(
...@@ -166,8 +82,8 @@ class RNNTModule(LightningModule): ...@@ -166,8 +82,8 @@ class RNNTModule(LightningModule):
FunctionalModule(lambda x: x.transpose(1, 2)), FunctionalModule(lambda x: x.transpose(1, 2)),
torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27),
TimeMasking(100, 0.2), torchaudio.transforms.TimeMasking(100, p=0.2),
TimeMasking(100, 0.2), torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))), FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)), FunctionalModule(lambda x: x.transpose(1, 2)),
) )
...@@ -219,7 +135,7 @@ class RNNTModule(LightningModule): ...@@ -219,7 +135,7 @@ class RNNTModule(LightningModule):
return Batch(features, feature_lengths, targets, target_lengths) return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List): def _test_collate_fn(self, samples: List):
return self._valid_collate_fn(samples), samples return self._valid_collate_fn(samples), [sample[2] for sample in samples]
def _step(self, batch, batch_idx, step_type): def _step(self, batch, batch_idx, step_type):
if batch is None: if batch is None:
...@@ -243,11 +159,6 @@ class RNNTModule(LightningModule): ...@@ -243,11 +159,6 @@ class RNNTModule(LightningModule):
return ( return (
[self.optimizer], [self.optimizer],
[ [
{
"scheduler": self.lr_scheduler,
"monitor": "Losses/val_loss",
"interval": "epoch",
},
{"scheduler": self.warmup_lr_scheduler, "interval": "step"}, {"scheduler": self.warmup_lr_scheduler, "interval": "step"},
], ],
) )
......
import json
import math
import os import os
from collections import namedtuple from typing import List
from typing import List, Tuple
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from common import (
GAIN,
Batch,
FunctionalModule,
GlobalStatsNormalization,
WarmupLR,
batch_by_token_count,
piecewise_linear_log,
post_process_hypos,
spectrogram_transform,
)
from pytorch_lightning import LightningModule from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base from torchaudio.models import RNNTBeamSearch, emformer_rnnt_base
from torchaudio.transforms import TimeMasking
from utils import GAIN, piecewise_linear_log, spectrogram_transform
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
def _batch_by_token_count(idx_target_lengths, token_limit):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if target_length == -1:
continue
if current_token_count + target_length > token_limit:
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
class CustomDataset(torch.utils.data.Dataset): class CustomDataset(torch.utils.data.Dataset):
r"""Sort samples by target length and batch to max durations.""" r"""Sort TEDLIUM3 samples by target length and batch to max durations."""
def __init__(self, base_dataset, max_token_limit): def __init__(self, base_dataset, max_token_limit):
super().__init__() super().__init__()
...@@ -46,6 +29,7 @@ class CustomDataset(torch.utils.data.Dataset): ...@@ -46,6 +29,7 @@ class CustomDataset(torch.utils.data.Dataset):
idx_target_lengths = [ idx_target_lengths = [
(idx, self._target_length(fileid, line)) for idx, (fileid, line) in enumerate(self.base_dataset._filelist) (idx, self._target_length(fileid, line)) for idx, (fileid, line) in enumerate(self.base_dataset._filelist)
] ]
idx_target_lengths = [(idx, length) for idx, length in idx_target_lengths if length != -1]
assert len(idx_target_lengths) > 0 assert len(idx_target_lengths) > 0
...@@ -53,13 +37,13 @@ class CustomDataset(torch.utils.data.Dataset): ...@@ -53,13 +37,13 @@ class CustomDataset(torch.utils.data.Dataset):
assert max_token_limit >= idx_target_lengths[-1][1] assert max_token_limit >= idx_target_lengths[-1][1]
self.batches = _batch_by_token_count(idx_target_lengths, max_token_limit) self.batches = batch_by_token_count(idx_target_lengths, max_token_limit)[:100]
def _target_length(self, fileid, line): def _target_length(self, fileid, line):
transcript_path = os.path.join(self.base_dataset._path, "stm", fileid) transcript_path = os.path.join(self.base_dataset._path, "stm", fileid)
with open(transcript_path + ".stm") as f: with open(transcript_path + ".stm") as f:
transcript = f.readlines()[line] transcript = f.readlines()[line]
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6) _, _, _, start_time, end_time, _, transcript = transcript.split(" ", 6)
if transcript.lower() == "ignore_time_segment_in_scoring\n": if transcript.lower() == "ignore_time_segment_in_scoring\n":
return -1 return -1
else: else:
...@@ -72,72 +56,31 @@ class CustomDataset(torch.utils.data.Dataset): ...@@ -72,72 +56,31 @@ class CustomDataset(torch.utils.data.Dataset):
return len(self.batches) return len(self.batches)
class FunctionalModule(torch.nn.Module): class EvalDataset(torch.utils.data.IterableDataset):
def __init__(self, functional): def __init__(self, base_dataset):
super().__init__() super().__init__()
self.functional = functional self.base_dataset = base_dataset
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
self.warmup_updates = warmup_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ali = [h.alignment[1:] for h in hypos]
hypos_ids = [h.tokens[1:] for h in hypos]
hypos_score = [[math.exp(h.score)] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
return nbest_batch def __iter__(self):
for sample in iter(self.base_dataset):
actual = sample[2].replace("\n", "")
if actual == "ignore_time_segment_in_scoring":
continue
yield sample
class RNNTModule(LightningModule): class TEDLIUM3RNNTModule(LightningModule):
def __init__( def __init__(
self, self,
*, *,
tedlium_path: str, tedlium_path: str,
sp_model_path: str, sp_model_path: str,
global_stats_path: str, global_stats_path: str,
reduction: str,
): ):
super().__init__() super().__init__()
self.model = emformer_rnnt_base(num_symbols=501) self.model = emformer_rnnt_base(num_symbols=501)
self.loss = torchaudio.transforms.RNNTLoss(reduction=reduction, clamp=1.0) self.loss = torchaudio.transforms.RNNTLoss(reduction="mean", clamp=1.0)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000) self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
...@@ -147,8 +90,8 @@ class RNNTModule(LightningModule): ...@@ -147,8 +90,8 @@ class RNNTModule(LightningModule):
FunctionalModule(lambda x: x.transpose(1, 2)), FunctionalModule(lambda x: x.transpose(1, 2)),
torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27),
TimeMasking(100, p=0.2), torchaudio.transforms.TimeMasking(100, p=0.2),
TimeMasking(100, p=0.2), torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))), FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)), FunctionalModule(lambda x: x.transpose(1, 2)),
) )
...@@ -216,7 +159,7 @@ class RNNTModule(LightningModule): ...@@ -216,7 +159,7 @@ class RNNTModule(LightningModule):
return Batch(features, feature_lengths, targets, target_lengths) return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List): def _test_collate_fn(self, samples: List):
return self._valid_collate_fn(samples), samples return self._valid_collate_fn(samples), [sample[2] for sample in samples]
def _step(self, batch, batch_idx, step_type): def _step(self, batch, batch_idx, step_type):
if batch is None: if batch is None:
...@@ -280,11 +223,11 @@ class RNNTModule(LightningModule): ...@@ -280,11 +223,11 @@ class RNNTModule(LightningModule):
return dataloader return dataloader
def test_dataloader(self): def test_dataloader(self):
dataset = torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="test") dataset = EvalDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="test"))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn) dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader return dataloader
def dev_dataloader(self): def dev_dataloader(self):
dataset = torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev") dataset = EvalDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev"))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn) dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader return dataloader
import logging import logging
import pathlib import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter from argparse import ArgumentParser
from lightning import RNNTModule from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from librispeech.lightning import LibriSpeechRNNTModule
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
from tedlium3.lightning import TEDLIUM3RNNTModule
def run_train(args): def get_trainer(args):
checkpoint_dir = args.exp_dir / "checkpoints" checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint( checkpoint = ModelCheckpoint(
checkpoint_dir, checkpoint_dir,
...@@ -29,65 +31,68 @@ def run_train(args): ...@@ -29,65 +31,68 @@ def run_train(args):
checkpoint, checkpoint,
train_checkpoint, train_checkpoint,
] ]
trainer = Trainer( return Trainer(
default_root_dir=args.exp_dir, default_root_dir=args.exp_dir,
max_epochs=args.epochs, max_epochs=args.epochs,
num_nodes=args.num_nodes, num_nodes=args.num_nodes,
gpus=args.gpus, gpus=args.gpus,
accelerator="gpu", accelerator="gpu",
strategy="ddp", strategy="ddp",
gradient_clip_val=5.0, gradient_clip_val=args.gradient_clip_val,
callbacks=callbacks, callbacks=callbacks,
) )
model = RNNTModule(
tedlium_path=str(args.tedlium_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
reduction=args.reduction,
)
trainer.fit(model)
def get_lightning_module(args):
if args.model_type == MODEL_TYPE_LIBRISPEECH:
return LibriSpeechRNNTModule(
librispeech_path=str(args.dataset_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
elif args.model_type == MODEL_TYPE_TEDLIUM3:
return TEDLIUM3RNNTModule(
tedlium_path=str(args.dataset_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
else:
raise ValueError(f"Encountered unsupported model type {args.model_type}.")
def _parse_args():
parser = ArgumentParser( def parse_args():
description=__doc__, parser = ArgumentParser()
formatter_class=RawTextHelpFormatter, parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument( parser.add_argument(
"--global-stats-path", "--global_stats_path",
default=pathlib.Path("global_stats.json"), default=pathlib.Path("global_stats.json"),
type=pathlib.Path, type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.", help="Path to JSON file containing feature means and stddevs.",
required=True,
) )
parser.add_argument( parser.add_argument(
"--tedlium-path", "--dataset_path",
type=pathlib.Path, type=pathlib.Path,
help="Path to datasets.",
required=True, required=True,
help="Path to TED-LIUM release 3 dataset.",
) )
parser.add_argument( parser.add_argument(
"--reduction", "--sp_model_path",
default="mean", type=pathlib.Path,
type=str, help="Path to SentencePiece model.",
help="Reduction option for RNN Transducer loss function." "(Default: ``mean``)", required=True,
) )
parser.add_argument( parser.add_argument(
"--sp-model-path", "--exp_dir",
default=pathlib.Path("./exp"),
type=pathlib.Path, type=pathlib.Path,
help="Path to SentencePiece model.", help="Directory to save checkpoints and logs to. (Default: './exp')",
) )
parser.add_argument( parser.add_argument(
"--num-nodes", "--num_nodes",
default=1, default=4,
type=int, type=int,
help="Number of nodes to use for training. (Default: 1)", help="Number of nodes to use for training. (Default: 4)",
) )
parser.add_argument( parser.add_argument(
"--gpus", "--gpus",
...@@ -101,20 +106,25 @@ def _parse_args(): ...@@ -101,20 +106,25 @@ def _parse_args():
type=int, type=int,
help="Number of epochs to train for. (Default: 120)", help="Number of epochs to train for. (Default: 120)",
) )
parser.add_argument(
"--gradient_clip_val", default=10.0, type=float, help="Value to clip gradient values to. (Default: 10.0)"
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging") parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args() return parser.parse_args()
def _init_logger(debug): def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s" fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S") logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main(): def cli_main():
args = _parse_args() args = parse_args()
_init_logger(args.debug) init_logger(args.debug)
run_train(args) model = get_lightning_module(args)
trainer = get_trainer(args)
trainer.fit(model)
if __name__ == "__main__": if __name__ == "__main__":
......
# Emformer RNN-T ASR Example
This directory contains sample implementations of training and evaluation pipelines for an on-device-oriented streaming-capable Emformer RNN-T ASR model.
## Usage
### Training
[`train.py`](./train.py) trains an Emformer RNN-T model on LibriSpeech using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the full LibriSpeech dataset and the SentencePiece model to be used to encode targets. The script also expects a file (--global_stats_path) that contains training set feature statistics; this file can be generated via [`global_stats.py`](./global_stats.py).
Sample SLURM command:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp_dir ./experiments --librispeech_path ./librispeech/ --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_4096.model
```
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on LibriSpeech test-clean.
Using the default configuration along with a SentencePiece model trained on LibriSpeech with vocab size 4096 and type bpe, [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves an WER of 0.0466 when evaluated on test-clean with [`eval.py`](./eval.py).
The table below contains WER results for various splits.
| | WER |
|:-------------------:|-------------:|
| test-clean | 0.0456 |
| test-other | 0.1066 |
| dev-clean | 0.0415 |
| dev-other | 0.1110 |
Sample SLURM command:
```
srun python eval.py --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --librispeech_path ./librispeech/ --sp_model_path ./spm_bpe_4096.model --use_cuda
```
import pathlib
from argparse import ArgumentParser
from lightning import RNNTModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
def run_train(args):
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True,
)
callbacks = [
checkpoint,
train_checkpoint,
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.num_nodes,
gpus=args.gpus,
accelerator="gpu",
strategy="ddp",
gradient_clip_val=10.0,
callbacks=callbacks,
)
model = RNNTModule(
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
trainer.fit(model)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--exp_dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--global_stats_path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech_path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
)
parser.add_argument(
"--sp_model_path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--num_nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=120,
type=int,
help="Number of epochs to train for. (Default: 120)",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()
import math
import torch
import torchaudio
DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
GAIN = pow(10, 0.05 * DECIBEL)
spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
def piecewise_linear_log(x):
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
# Emformer RNN-T ASR Example for TED-LIUM release 3 dataset
This directory contains sample implementations of training and evaluation pipelines for an on-device-oriented streaming-capable Emformer RNN-T ASR model.
## Usage
### Training
[`train.py`](./train.py) trains an Emformer RNN-T model on TED-LIUM release 3 using PyTorch Lightning. Note that the script expects users to have access to GPU nodes for training and provide paths to the full TED-LIUM release 3 dataset and the SentencePiece model to be used to encode targets.
Sample SLURM command:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --exp-dir ./experiments --tedlium-path ./datasets/ --global-stats-path ./global_stats.json --sp-model-path ./spm_bpe_500.model
```
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on TED-LIUM release 3 test set.
The table below contains WER results for dev and test subsets of TED-LIUM release 3.
| | WER |
|:-----------:|-------------:|
| dev | 0.108 |
| test | 0.098 |
Sample SLURM command:
```
srun python eval.py --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium-path ./datasets/ --sp-model-path ./spm-bpe-500.model --use-cuda
```
### Evaluation using `torchaudio.pipelines.EMFORMER_RNNT_BASE_TEDLIUM3` bundle
[`eval_pipeline.py`](./eval_pipeline.py) evaluates the `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3.
You should be able to get identical WER results in the above table.
Sample SLURM command:
```
srun python eval_pipeline.py --tedlium-path ./datasets/ --use-cuda
```
"""Generate feature statistics for TED-LIUM release 3 training set.
Example:
python compute_global_stats.py --tedlium-path /home/datasets/
"""
import json
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torchaudio
from utils import GAIN, piecewise_linear_log, spectrogram_transform
logger = logging.getLogger(__name__)
def _parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--tedlium-path",
required=True,
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--output-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="File to save feature statistics to. (Default: './global_stats.json')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _compute_stats(dataset):
E_x = 0.0
E_x_2 = 0.0
N = 0.0
for idx, data in enumerate(dataset):
waveform = data[0].squeeze()
mel_spec = spectrogram_transform(waveform)
scaled_mel_spec = piecewise_linear_log(mel_spec * GAIN)
mel_sum = scaled_mel_spec.sum(-1)
mel_sum_sq = scaled_mel_spec.pow(2).sum(-1)
M = scaled_mel_spec.size(1)
E_x = E_x * (N / (N + M)) + mel_sum / (N + M)
E_x_2 = E_x_2 * (N / (N + M)) + mel_sum_sq / (N + M)
N += M
if idx % 100 == 0:
logger.info(f"Processed {idx}")
return E_x, (E_x_2 - E_x ** 2) ** 0.5
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
dataset = torchaudio.datasets.TEDLIUM(args.tedlium_path, release="release3", subset="train")
mean, std = _compute_stats(dataset)
invstd = 1 / std
stats_dict = {
"mean": mean.tolist(),
"invstddev": invstd.tolist(),
}
with open(args.output_path, "w") as f:
json.dump(stats_dict, f, indent=2)
if __name__ == "__main__":
cli_main()
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torch
import torchaudio
from lightning import RNNTModule
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def _eval_subset(model, subset):
total_edit_distance = 0.0
total_length = 0.0
if subset == "dev":
dataloader = model.dev_dataloader()
else:
dataloader = model.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2].replace("\n", "")
if actual == "ignore_time_segment_in_scoring":
continue
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.info(f"Final WER for {subset} set: {total_edit_distance / total_length}")
def run_eval(args):
model = RNNTModule.load_from_checkpoint(
args.checkpoint_path,
tedlium_path=str(args.tedlium_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
reduction="mean",
).eval()
if args.use_cuda:
model = model.to(device="cuda")
_eval_subset(model, "dev")
_eval_subset(model, "test")
def _parse_args():
parser = ArgumentParser(
description=__doc__,
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--tedlium-path",
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
run_eval(args)
if __name__ == "__main__":
cli_main()
"""Train the SentencePiece model by using the transcripts of TED-LIUM release 3 training set.
Example:
python train_spm.py --tedlium-path /home/datasets/
"""
import logging
import os
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
logger = logging.getLogger(__name__)
def _parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--tedlium-path",
required=True,
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--output-dir",
default=pathlib.Path("./"),
type=pathlib.Path,
help="File to save feature statistics to. (Default: './')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _extract_train_text(tedlium_path, output_dir):
stm_path = tedlium_path / "TEDLIUM_release-3/data/stm/"
transcripts = []
for file in sorted(os.listdir(stm_path)):
if file.endswith(".stm"):
file = os.path.join(stm_path, file)
with open(file) as f:
for line in f.readlines():
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = line.split(" ", 6)
if transcript == "ignore_time_segment_in_scoring\n":
continue
else:
transcript = transcript.lower().replace("<unk>", "<garbage>")
transcripts.append(transcript)
with open(output_dir / "text_train.txt", "w") as f:
f.writelines(transcripts)
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
_extract_train_text(args.tedlium_path, args.output_dir)
spm.SentencePieceTrainer.train(
input=args.output_dir / "text_train.txt",
vocab_size=500,
model_prefix="spm_bpe_500",
model_type="bpe",
input_sentence_size=100000000,
character_coverage=1.0,
user_defined_symbols=["<garbage>"],
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
logger.info("Successfully trained the sentencepiece model")
if __name__ == "__main__":
cli_main()
../librispeech_emformer_rnnt/utils.py
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment