Commit 8f68b3f0 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add training recipe of Emformer trained on TED-LIUM release 3 dataset (#2195)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2195

Reviewed By: hwangjeff

Differential Revision: D33950179

Pulled By: nateanl

fbshipit-source-id: 5fcfa4f433fffdcbb3b8e97f7c90fb8f723a30a2
parent c00f65da
# Emformer RNN-T ASR Example for the TED-LIUM release 3 dataset
This directory contains sample implementations of training and evaluation pipelines for an on-device-oriented, streaming-capable Emformer RNN-T ASR model.
## Usage
### Training
[`train.py`](./train.py) trains an Emformer RNN-T model on TED-LIUM release 3 using PyTorch Lightning. Note that the script expects access to GPU nodes and requires paths to the full TED-LIUM release 3 dataset and to a SentencePiece model for encoding targets. The [`train_spm.py`](./train_spm.py) and [`compute_global_stats.py`](./compute_global_stats.py) scripts in this directory can be used to produce the SentencePiece model and the global feature statistics file, respectively.

Sample SLURM command:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --exp-dir ./experiments --tedlium-path ./datasets/ --global-stats-path ./global_stats.json --sp-model-path ./spm_bpe_500.model
```
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on the dev and test sets of TED-LIUM release 3.

The table below contains WER results for the dev and test sets of TED-LIUM release 3.

| Subset | WER   |
|:------:|------:|
| dev    | 0.108 |
| test   | 0.098 |

Sample SLURM command:
```
srun python eval.py --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium-path ./datasets/ --sp-model-path ./spm_bpe_500.model --use-cuda
```
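The WER values above are computed at the word level with `torchaudio.functional.edit_distance`, as in `compute_word_level_distance` in [`eval.py`](./eval.py); segments labeled `ignore_time_segment_in_scoring` are skipped. A minimal sketch of the metric (the reference and hypothesis strings are made-up examples):
```
import torchaudio

def word_error_rate(reference: str, hypothesis: str) -> float:
    # Word-level Levenshtein distance normalized by the number of reference words.
    ref_words = reference.lower().split()
    hyp_words = hypothesis.lower().split()
    return torchaudio.functional.edit_distance(ref_words, hyp_words) / len(ref_words)

print(word_error_rate("the quick brown fox", "the quick brown box"))  # 0.25
```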
"""Generate feature statistics for TED-LIUM release 3 training set.
Example:
python compute_global_stats.py --tedlium-path /home/datasets/
"""
import json
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torchaudio
from utils import GAIN, piecewise_linear_log, spectrogram_transform
logger = logging.getLogger(__name__)
def _parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--tedlium-path",
required=True,
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--output-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="File to save feature statistics to. (Default: './global_stats.json')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _compute_stats(dataset):
E_x = 0.0
E_x_2 = 0.0
N = 0.0
for idx, data in enumerate(dataset):
waveform = data[0].squeeze()
mel_spec = spectrogram_transform(waveform)
scaled_mel_spec = piecewise_linear_log(mel_spec * GAIN)
mel_sum = scaled_mel_spec.sum(-1)
mel_sum_sq = scaled_mel_spec.pow(2).sum(-1)
M = scaled_mel_spec.size(1)
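# Streaming update: fold this utterance's M frames into the running per-bin moments,
# E_x <- E_x * N / (N + M) + sum(x) / (N + M), and likewise for E_x_2 with x ** 2.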
E_x = E_x * (N / (N + M)) + mel_sum / (N + M)
E_x_2 = E_x_2 * (N / (N + M)) + mel_sum_sq / (N + M)
N += M
if idx % 100 == 0:
logger.info(f"Processed {idx}")
return E_x, (E_x_2 - E_x ** 2) ** 0.5
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
dataset = torchaudio.datasets.TEDLIUM(args.tedlium_path, release="release3", subset="train")
mean, std = _compute_stats(dataset)
invstd = 1 / std
stats_dict = {
"mean": mean.tolist(),
"invstddev": invstd.tolist(),
}
with open(args.output_path, "w") as f:
json.dump(stats_dict, f, indent=2)
if __name__ == "__main__":
cli_main()
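The script above accumulates per-bin first and second moments incrementally, so the full training set never has to be held in memory. A standalone sanity-check sketch (not part of the recipe) comparing the streaming update against a direct computation on random 80-bin features:
```
import torch

torch.manual_seed(0)
# Fake (mel_bins, time) spectrograms of varying length.
specs = [torch.rand(80, int(torch.randint(50, 200, ()))) for _ in range(5)]

# Streaming accumulation, mirroring _compute_stats.
E_x, E_x_2, N = 0.0, 0.0, 0.0
for spec in specs:
    M = spec.size(1)
    E_x = E_x * (N / (N + M)) + spec.sum(-1) / (N + M)
    E_x_2 = E_x_2 * (N / (N + M)) + spec.pow(2).sum(-1) / (N + M)
    N += M
stream_mean, stream_std = E_x, (E_x_2 - E_x ** 2) ** 0.5

# Direct computation over all frames at once.
all_frames = torch.cat(specs, dim=-1)
print(torch.allclose(stream_mean, all_frames.mean(-1), atol=1e-5))                 # expected: True
print(torch.allclose(stream_std, all_frames.std(-1, unbiased=False), atol=1e-4))   # expected: True
```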
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torch
import torchaudio
from lightning import RNNTModule
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def _eval_subset(model, subset):
total_edit_distance = 0.0
total_length = 0.0
if subset == "dev":
dataloader = model.dev_dataloader()
else:
dataloader = model.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2].replace("\n", "")
if actual == "ignore_time_segment_in_scoring":
continue
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.info(f"Final WER for {subset} set: {total_edit_distance / total_length}")
def run_eval(args):
model = RNNTModule.load_from_checkpoint(
args.checkpoint_path,
tedlium_path=str(args.tedlium_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
reduction="mean",
).eval()
if args.use_cuda:
model = model.to(device="cuda")
_eval_subset(model, "dev")
_eval_subset(model, "test")
def _parse_args():
parser = ArgumentParser(
description=__doc__,
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--tedlium-path",
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
run_eval(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
14.762723922729492,
16.020633697509766,
16.911531448364258,
16.80994415283203,
18.72406005859375,
18.84550666809082,
19.021404266357422,
19.623443603515625,
19.403806686401367,
19.52766990661621,
19.253433227539062,
19.211227416992188,
19.216045379638672,
19.315574645996094,
19.267532348632812,
19.146976470947266,
18.98181915283203,
18.81462287902832,
18.67916488647461,
18.5198917388916,
18.360441207885742,
18.18699836730957,
18.008447647094727,
17.82094955444336,
17.644861221313477,
17.51972007751465,
17.51348876953125,
17.171707153320312,
17.070415496826172,
17.21990394592285,
16.868940353393555,
17.048307418823242,
16.894960403442383,
17.04732322692871,
16.955705642700195,
17.053966522216797,
17.037548065185547,
17.03425407409668,
17.03618621826172,
16.979724884033203,
16.889690399169922,
16.779285430908203,
16.689767837524414,
16.62590789794922,
16.600360870361328,
16.610321044921875,
16.692338943481445,
16.61323356628418,
16.638328552246094,
16.494739532470703,
16.42980194091797,
16.23759651184082,
16.144210815429688,
16.018585205078125,
15.985218048095703,
15.947102546691895,
15.894798278808594,
15.832999229431152,
15.704426765441895,
15.538087844848633,
15.378302574157715,
15.19461441040039,
15.00456714630127,
14.861663818359375,
14.676336288452148,
14.594626426696777,
14.561753273010254,
14.464197158813477,
14.43082046508789,
14.388801574707031,
14.257562637329102,
14.231459617614746,
14.19768238067627,
14.123900413513184,
14.159867286682129,
14.059795379638672,
13.968880653381348,
13.927794456481934,
13.645783424377441,
12.086114883422852
],
"invstddev": [
0.3553205132484436,
0.3363242745399475,
0.3194723129272461,
0.3199574947357178,
0.28755369782447815,
0.2879481613636017,
0.27939942479133606,
0.27543479204177856,
0.2806696891784668,
0.28141146898269653,
0.2753477990627289,
0.274241179227829,
0.27815768122673035,
0.27794352173805237,
0.2763032615184784,
0.2744459807872772,
0.27375343441963196,
0.27415215969085693,
0.27628427743911743,
0.27667510509490967,
0.2806207835674286,
0.28371962904930115,
0.2893684506416321,
0.2944427728652954,
0.2989389896392822,
0.30326008796691895,
0.30760079622268677,
0.3089521527290344,
0.3105863034725189,
0.31274259090423584,
0.31318506598472595,
0.3154853880405426,
0.3167822062969208,
0.3182784914970398,
0.31875282526016235,
0.3185810148715973,
0.31908345222473145,
0.3207632303237915,
0.32282087206840515,
0.3241617977619171,
0.3260948061943054,
0.32735878229141235,
0.32947203516960144,
0.33052706718444824,
0.3309975266456604,
0.3301711678504944,
0.32793518900871277,
0.3252142369747162,
0.32336947321891785,
0.32320502400398254,
0.3264254927635193,
0.32860180735588074,
0.3322647213935852,
0.3100382685661316,
0.3216720223426819,
0.32280418276786804,
0.32710719108581543,
0.3284962773323059,
0.3319654166698456,
0.32880258560180664,
0.33075764775276184,
0.32947179675102234,
0.32880640029907227,
0.3296009302139282,
0.324250727891922,
0.3247823715209961,
0.328702837228775,
0.32418182492256165,
0.3247915208339691,
0.3251509964466095,
0.31811773777008057,
0.3195462226867676,
0.3187839686870575,
0.31459841132164,
0.32190003991127014,
0.3193890154361725,
0.315574049949646,
0.317360520362854,
0.3075887858867645,
0.3034747838973999
]
}
import json
import math
import os
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base
from torchaudio.transforms import TimeMasking
from utils import GAIN, piecewise_linear_log, spectrogram_transform
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
def _batch_by_token_count(idx_target_lengths, token_limit):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if target_length == -1:
continue
if current_token_count + target_length > token_limit:
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
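# Example: with token_limit=10 and (already sorted) lengths [3, 4, 5, 6], the greedy
# grouping yields index batches [[0, 1], [2], [3]] -- a new batch starts whenever adding
# the next item would exceed the limit, and entries with length -1 are skipped entirely.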
class CustomDataset(torch.utils.data.Dataset):
r"""Sort samples by target length and batch to max durations."""
def __init__(self, base_dataset, max_token_limit):
super().__init__()
self.base_dataset = base_dataset
idx_target_lengths = [
(idx, self._target_length(fileid, line)) for idx, (fileid, line) in enumerate(self.base_dataset._filelist)
]
assert len(idx_target_lengths) > 0
idx_target_lengths = sorted(idx_target_lengths, key=lambda x: x[1])
assert max_token_limit >= idx_target_lengths[-1][1]
self.batches = _batch_by_token_count(idx_target_lengths, max_token_limit)
def _target_length(self, fileid, line):
transcript_path = os.path.join(self.base_dataset._path, "stm", fileid)
with open(transcript_path + ".stm") as f:
transcript = f.readlines()[line]
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6)
if transcript.lower() == "ignore_time_segment_in_scoring\n":
return -1
else:
return float(end_time) - float(start_time)
def __getitem__(self, idx):
return [self.base_dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
self.warmup_updates = warmup_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
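# Linear warmup: with the base learning rate of 5e-4 and warmup_updates=10000 used below,
# step 1000 runs at 5e-5 and step 5000 at 2.5e-4; the min(1.0, ...) clamp keeps the rate
# at the full 5e-4 for every step past the warmup.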
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ali = [h.alignment[1:] for h in hypos]
hypos_ids = [h.tokens[1:] for h in hypos]
hypos_score = [[math.exp(h.score)] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
return nbest_batch
class RNNTModule(LightningModule):
def __init__(
self,
*,
tedlium_path: str,
sp_model_path: str,
global_stats_path: str,
reduction: str,
):
super().__init__()
self.model = emformer_rnnt_base(num_symbols=501)
self.loss = torchaudio.transforms.RNNTLoss(reduction=reduction, clamp=1.0)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
TimeMasking(100, p=0.2),  # p limits each mask to 20% of the utterance length
TimeMasking(100, p=0.2),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)),
)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)),
)
self.tedlium_path = tedlium_path
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.blank_idx = self.sp_model.get_piece_size()
def _extract_labels(self, samples: List):
"""Convert text transcript into int labels.
Note:
There are ``<unk>`` tokens in the training set that are regarded as normal tokens
by the SentencePiece model. This will impact RNNT decoding since the decoding result
of ``<unk>`` will be ``?? unk ??`` and will not be excluded from the final prediction.
To address it, here we replace ``<unk>`` with ``<garbage>`` and set
``user_defined_symbols=["<garbage>"]`` in the SentencePiece model training.
Then we map the index of ``<garbage>`` to the real ``unknown`` index.
"""
targets = [
self.sp_model.encode(sample[2].lower().replace("<unk>", "<garbage>").replace("\n", ""))
for sample in samples
]
targets = [
[ele if ele != 4 else self.sp_model.unk_id() for ele in target] for target in targets
] # map id of <unk> token to unk_id
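# 4 is the id of the user-defined <garbage> piece: train_spm.py reserves bos=0, pad=1,
# eos=2, unk=3, and SentencePiece assigns user_defined_symbols the next available ids.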
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _train_extract_features(self, samples: List):
mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.train_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _valid_extract_features(self, samples: List):
mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.valid_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _train_collate_fn(self, samples: List):
features, feature_lengths = self._train_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _valid_collate_fn(self, samples: List):
features, feature_lengths = self._valid_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List):
return self._valid_collate_fn(samples), samples
def _step(self, batch, batch_idx, step_type):
if batch is None:
return None
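# Prepend the blank id to the targets as a start-of-sequence token for the prediction
# network; the transducer loss below is still computed against the un-prefixed targets.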
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _ = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[
{"scheduler": self.warmup_lr_scheduler, "interval": "step"},
],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "train")
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
def train_dataloader(self):
dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="train"), 100)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._train_collate_fn,
num_workers=10,
shuffle=True,
)
return dataloader
def val_dataloader(self):
dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev"), 100)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._valid_collate_fn,
num_workers=10,
)
return dataloader
def test_dataloader(self):
dataset = torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="test")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader
def dev_dataloader(self):
dataset = torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
from lightning import RNNTModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
def run_train(args):
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True,
)
callbacks = [
checkpoint,
train_checkpoint,
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.num_nodes,
gpus=args.gpus,
accelerator="gpu",
strategy="ddp",
gradient_clip_val=5.0,
callbacks=callbacks,
)
model = RNNTModule(
tedlium_path=str(args.tedlium_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
reduction=args.reduction,
)
trainer.fit(model)
def _parse_args():
parser = ArgumentParser(
description=__doc__,
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--tedlium-path",
type=pathlib.Path,
required=True,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--reduction",
default="mean",
type=str,
help="Reduction option for RNN Transducer loss function." "(Default: ``mean``)",
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--num-nodes",
default=1,
type=int,
help="Number of nodes to use for training. (Default: 1)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=120,
type=int,
help="Number of epochs to train for. (Default: 120)",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
run_train(args)
if __name__ == "__main__":
cli_main()
"""Train the SentencePiece model by using the transcripts of TED-LIUM release 3 training set.
Example:
python train_spm.py --tedlium-path /home/datasets/
"""
import logging
import os
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
logger = logging.getLogger(__name__)
def _parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--tedlium-path",
required=True,
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--output-dir",
default=pathlib.Path("./"),
type=pathlib.Path,
help="File to save feature statistics to. (Default: './')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _extract_train_text(tedlium_path, output_dir):
stm_path = tedlium_path / "TEDLIUM_release-3/data/stm/"
transcripts = []
for file in sorted(os.listdir(stm_path)):
if file.endswith(".stm"):
file = os.path.join(stm_path, file)
with open(file) as f:
for line in f.readlines():
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = line.split(" ", 6)
if transcript == "ignore_time_segment_in_scoring\n":
continue
else:
transcript = transcript.lower().replace("<unk>", "<garbage>")
transcripts.append(transcript)
with open(output_dir / "text_train.txt", "w") as f:
f.writelines(transcripts)
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
_extract_train_text(args.tedlium_path, args.output_dir)
spm.SentencePieceTrainer.train(
input=args.output_dir / "text_train.txt",
vocab_size=500,
model_prefix="spm_bpe_500",
model_type="bpe",
input_sentence_size=100000000,
character_coverage=1.0,
user_defined_symbols=["<garbage>"],
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
logger.info("Successfully trained the sentencepiece model")
if __name__ == "__main__":
cli_main()
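The trained `spm_bpe_500.model` is what `lightning.py` loads to encode transcripts: `<unk>` in the raw text is routed through the user-defined `<garbage>` piece and then mapped back to the unknown id, as in `RNNTModule._extract_labels`. A minimal sketch of that round trip (the example sentence is made up; the id 4 for `<garbage>` assumes the symbol id settings above):
```
import sentencepiece as spm

sp_model = spm.SentencePieceProcessor(model_file="spm_bpe_500.model")

text = "so we have a <unk> here".lower().replace("<unk>", "<garbage>")
token_ids = sp_model.encode(text)
# Map the <garbage> placeholder (id 4 under bos=0/pad=1/eos=2/unk=3) back to the
# real unknown id, mirroring RNNTModule._extract_labels.
token_ids = [i if i != 4 else sp_model.unk_id() for i in token_ids]
print(token_ids)
print(sp_model.decode([i for i in token_ids if i != sp_model.unk_id()]))
```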
`utils.py` is a symlink to `../librispeech_emformer_rnnt/utils.py`, which provides `GAIN`, `piecewise_linear_log`, and `spectrogram_transform`.