"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "fa2836c224afea9ab31da218f1c1ce34d5155e5c"
Commit 825a5976 authored by hwangjeff, committed by Facebook GitHub Bot

Add SentencePiece model training script for LibriSpeech Emformer RNN-T (#2218)

Summary:
Adds SentencePiece model training script for LibriSpeech Emformer RNN-T example recipe; updates readme with references.

Pull Request resolved: https://github.com/pytorch/audio/pull/2218

Reviewed By: nateanl

Differential Revision: D34177295

Pulled By: hwangjeff

fbshipit-source-id: 9f32805af792fb8c6f834f2812e20104177a6c43
parent 738d2f8e
Sample SLURM command for evaluation:
```
srun python eval.py --model_type librispeech --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --dataset_path ./datasets/librispeech --sp_model_path ./spm_bpe_4096.model --use_cuda
```
Using the sample training command above along with a SentencePiece model trained on LibriSpeech with vocab size 4096 and type bpe, [`train.py`](./train.py) produces a model with 76.7M parameters (307MB) that achieves a WER of 0.0456 when evaluated on test-clean with [`eval.py`](./eval.py). The script used for training the SentencePiece model referenced by the training command above can be found at [`librispeech/train_spm.py`](./librispeech/train_spm.py); a pretrained SentencePiece model can be downloaded [here](https://download.pytorch.org/torchaudio/pipeline-assets/spm_bpe_4096_librispeech.model).

The table below contains WER results for various splits.
"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.

Example:
python train_spm.py --librispeech-path ./datasets
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter

import sentencepiece as spm


def get_transcript_text(transcript_path):
    # Each line has the form "<utterance-id> <TRANSCRIPT>"; keep only the
    # lowercased transcript text.
    with open(transcript_path) as f:
        return [line.strip().split(" ", 1)[1].lower() for line in f]


def get_transcripts(dataset_path):
    # Transcript files live at <speaker>/<chapter>/<...>.trans.txt under each split.
    transcript_paths = dataset_path.glob("*/*/*.trans.txt")
    merged_transcripts = []
    for path in transcript_paths:
        merged_transcripts += get_transcript_text(path)
    return merged_transcripts


def train_spm(sentences):
    # Train a BPE SentencePiece model in memory and return the serialized model bytes.
    model_writer = io.BytesIO()
    spm.SentencePieceTrainer.train(
        sentence_iterator=iter(sentences),
        model_writer=model_writer,
        vocab_size=4096,
        model_type="bpe",
        input_sentence_size=-1,
        character_coverage=1.0,
        bos_id=0,
        pad_id=1,
        eos_id=2,
        unk_id=3,
    )
    return model_writer.getvalue()


def parse_args():
    parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--librispeech-path",
        required=True,
        type=pathlib.Path,
        help="Path to LibriSpeech dataset.",
    )
    parser.add_argument(
        "--output-file",
        default=pathlib.Path("./spm_bpe_4096.model"),
        type=pathlib.Path,
        help="File to save model to. (Default: './spm_bpe_4096.model')",
    )
    return parser.parse_args()


def run_cli():
    args = parse_args()
    root = args.librispeech_path / "LibriSpeech"
    splits = ["train-clean-100", "train-clean-360", "train-other-500"]
    merged_transcripts = []
    for split in splits:
        path = pathlib.Path(root) / split
        merged_transcripts += get_transcripts(path)
    model = train_spm(merged_transcripts)
    with open(args.output_file, "wb") as f:
        f.write(model)


if __name__ == "__main__":
    run_cli()
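The transcript-collection logic above can be exercised without downloading LibriSpeech by building a throwaway directory tree that mimics its layout. A minimal sketch using only the standard library; the utterance IDs and transcript text below are made up for illustration:

```python
import pathlib
import tempfile

# Build a miniature LibriSpeech-style layout (<speaker>/<chapter>/*.trans.txt)
# under a temporary "train-clean-100" split directory.
split_root = pathlib.Path(tempfile.mkdtemp()) / "train-clean-100"
chapter_dir = split_root / "19" / "198"
chapter_dir.mkdir(parents=True)
(chapter_dir / "19-198.trans.txt").write_text(
    "19-198-0000 IT WAS A DARK NIGHT\n"
    "19-198-0001 THE STEW WAS READY\n"
)

# Same globbing and parsing as get_transcripts/get_transcript_text above:
# strip the utterance ID, lowercase the rest.
merged = []
for path in sorted(split_root.glob("*/*/*.trans.txt")):
    with open(path) as f:
        merged += [line.strip().split(" ", 1)[1].lower() for line in f]

print(merged)  # ['it was a dark night', 'the stew was ready']
```

The resulting list of plain lowercase sentences is exactly the form `train_spm` expects for its `sentence_iterator`.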