Commit 4d0095a5 authored by nateanl, committed by Facebook GitHub Bot

Add training recipe for Emformer RNNT trained on MuST-C release v2.0 dataset (#2219)

Summary:
- Add a MUSTC dataset under examples
- Add a lightning module for MuST-C dataset
- Refactor `train.py`, `eval.py`, and `global_stats.py` scripts

Pull Request resolved: https://github.com/pytorch/audio/pull/2219

Reviewed By: hwangjeff

Differential Revision: D34180466

Pulled By: nateanl

fbshipit-source-id: 9fc74ce7527da1a81dd0738e124428f9d516d164
parent 825a5976
@@ -25,12 +25,12 @@ Currently, we have training recipes for the LibriSpeech and TED-LIUM Release 3 datasets
Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --model_type librispeech --exp_dir ./experiments --dataset_path ./datasets/librispeech --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_4096.model
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --model-type librispeech --exp-dir ./experiments --dataset-path ./datasets/librispeech --global-stats-path ./global_stats.json --sp-model-path ./spm_bpe_4096.model
```
Sample SLURM command for evaluation:
```
srun python eval.py --model_type librispeech --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --dataset_path ./datasets/librispeech --sp_model_path ./spm_bpe_4096.model --use_cuda
srun python eval.py --model_type librispeech --checkpoint_path ./experiments/checkpoints/epoch=119-step=208079.ckpt --dataset-path ./datasets/librispeech --sp-model-path ./spm_bpe_4096.model --use-cuda
```
The script used for training the SentencePiece model that's referenced by the training command above can be found at [`librispeech/train_spm.py`](./librispeech/train_spm.py); a pretrained SentencePiece model can be downloaded [here](https://download.pytorch.org/torchaudio/pipeline-assets/spm_bpe_4096_librispeech.model).
@@ -52,12 +52,12 @@ Whereas the LibriSpeech model is configured with a vocabulary size of 4096, the
Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --model_type tedlium3 --exp_dir ./experiments --dataset_path ./datasets/tedlium --global_stats_path ./global_stats.json --sp_model_path ./spm_bpe_500.model --gradient_clip_val 5.0
srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --model-type tedlium3 --exp-dir ./experiments --dataset-path ./datasets/tedlium --global-stats-path ./global_stats.json --sp-model-path ./spm_bpe_500.model --num-nodes 1 --gradient-clip-val 5.0
```
Sample SLURM command for evaluation:
```
srun python eval.py --model_type tedlium3 --checkpoint_path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium_path ./datasets/tedlium --sp_model_path ./spm-bpe-500.model --use_cuda
srun python eval.py --model-type tedlium3 --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt --dataset-path ./datasets/tedlium --sp-model-path ./spm_bpe_500.model --use-cuda
```
The table below contains WER results for dev and test subsets of TED-LIUM release 3.
@@ -68,3 +68,25 @@ The table below contains WER results for dev and test subsets of TED-LIUM release 3.
| test | 0.098 |
[`tedlium3/eval_pipeline.py`](./tedlium3/eval_pipeline.py) evaluates the pre-trained `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. Running the script should produce WER results that are identical to those in the above table.
### MuST-C release v2.0
The MuST-C model is configured with a vocabulary size of 500. Consequently, the last linear layer in its joiner has an output dimension of 501 (500 + 1 to account for the blank symbol). Unlike the transcripts of the two datasets above, MuST-C's transcripts are cased and punctuated; we preserve the casing and punctuation when training the SentencePiece model.
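For reference, a minimal sketch of how this configuration maps onto model construction, mirroring the MuST-C lightning module added in this commit (the SentencePiece model path is an assumption):
```
import sentencepiece as spm
from torchaudio.models import emformer_rnnt_base

# 500 subword pieces + 1 blank symbol -> 501 output units in the joiner.
sp_model = spm.SentencePieceProcessor(model_file="./spm_bpe_500.model")  # assumed path
blank_idx = sp_model.get_piece_size()  # 500; the blank symbol takes the last index
model = emformer_rnnt_base(num_symbols=501)
```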
Sample SLURM command for training:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 1 --ntasks-per-node=8 python train.py --model-type mustc --exp-dir ./experiments --dataset-path ./datasets/mustc --global-stats-path ./global_stats.json --sp-model-path ./spm_bpe_500.model --num-nodes 1 --gradient-clip-val 5.0
```
Sample SLURM command for evaluation:
```
srun python eval.py --model-type mustc --checkpoint-path ./experiments/checkpoints/epoch=55-step=106679.ckpt --dataset-path ./datasets/mustc --sp-model-path ./spm_bpe_500.model --use-cuda
```
The table below contains WER results for the dev, tst-COMMON, and tst-HE subsets of MuST-C release v2.0.
| | WER |
|:-----------------:|-------------:|
| dev | 0.190 |
| tst-COMMON | 0.213 |
| tst-HE | 0.186 |
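The reported WERs are computed at the word level on lowercased text; a minimal sketch of the metric, mirroring `compute_word_level_distance` in `eval.py`:
```
import torchaudio

def compute_word_level_distance(seq1, seq2):
    # Edit distance between lowercased, whitespace-tokenized word sequences.
    return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())

# WER for a subset = sum of edit distances / total number of reference words.
```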
@@ -11,6 +11,7 @@ from torchaudio.models import Hypothesis
MODEL_TYPE_LIBRISPEECH = "librispeech"
MODEL_TYPE_TEDLIUM3 = "tedlium3"
MODEL_TYPE_MUSTC = "mustc"
DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
......
@@ -4,8 +4,9 @@ from argparse import ArgumentParser
import torch
import torchaudio
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC
from librispeech.lightning import LibriSpeechRNNTModule
from mustc.lightning import MuSTCRNNTModule
from tedlium3.lightning import TEDLIUM3RNNTModule
@@ -16,10 +17,9 @@ def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def run_eval(model):
def run_eval_subset(model, dataloader, subset):
total_edit_distance = 0
total_length = 0
dataloader = model.test_dataloader()
with torch.no_grad():
for idx, (batch, transcripts) in enumerate(dataloader):
actual = transcripts[0]
@@ -28,7 +28,27 @@ def run_eval(model):
total_length += len(actual.split())
if idx % 100 == 0:
logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.info(f"Final WER: {total_edit_distance / total_length}")
logger.info(f"Final WER for {subset} set: {total_edit_distance / total_length}")
def run_eval(model, model_type):
if model_type == MODEL_TYPE_LIBRISPEECH:
dataloader = model.test_dataloader()
run_eval_subset(model, dataloader, "test")
elif model_type == MODEL_TYPE_TEDLIUM3:
dev_loader = model.dev_dataloader()
test_loader = model.test_dataloader()
run_eval_subset(model, dev_loader, "dev")
run_eval_subset(model, test_loader, "test")
elif model_type == MODEL_TYPE_MUSTC:
dev_loader = model.dev_dataloader()
test_common_loader = model.test_common_dataloader()
test_he_loader = model.test_he_dataloader()
run_eval_subset(model, dev_loader, "dev")
run_eval_subset(model, test_common_loader, "tst-COMMON")
run_eval_subset(model, test_he_loader, "tst-HE")
else:
raise ValueError(f"Encountered unsupported model type {model_type}.")
def get_lightning_module(args):
@@ -46,36 +66,45 @@ def get_lightning_module(args):
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
elif args.model_type == MODEL_TYPE_MUSTC:
return MuSTCRNNTModule.load_from_checkpoint(
args.checkpoint_path,
mustc_path=str(args.dataset_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
else:
raise ValueError(f"Encountered unsupported model type {args.model_type}.")
def parse_args():
parser = ArgumentParser()
parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
parser.add_argument(
"--checkpoint_path",
"--model-type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC], required=True
)
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--global_stats_path",
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--dataset_path",
"--dataset-path",
type=pathlib.Path,
help="Path to dataset.",
)
parser.add_argument(
"--sp_model_path",
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--use_cuda",
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
@@ -96,7 +125,7 @@ def cli_main():
model = get_lightning_module(args)
if args.use_cuda:
model = model.to(device="cuda")
run_eval(model)
run_eval(model, args.model_type)
if __name__ == "__main__":
......
@@ -11,23 +11,33 @@ from argparse import ArgumentParser, RawTextHelpFormatter
import torch
import torchaudio
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform
from common import (
MODEL_TYPE_LIBRISPEECH,
MODEL_TYPE_TEDLIUM3,
MODEL_TYPE_MUSTC,
piecewise_linear_log,
spectrogram_transform,
)
from mustc.dataset import MUSTC
logger = logging.getLogger()
def parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
parser.add_argument(
"--dataset_path",
"--model-type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC], required=True
)
parser.add_argument(
"--dataset-path",
required=True,
type=pathlib.Path,
help="Path to dataset. "
"For LibriSpeech, all of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.",
)
parser.add_argument(
"--output_path",
"--output-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="File to save feature statistics to. (Default: './global_stats.json')",
@@ -68,6 +78,8 @@ def get_dataset(args):
)
elif args.model_type == MODEL_TYPE_TEDLIUM3:
return torchaudio.datasets.TEDLIUM(args.dataset_path, release="release3", subset="train")
elif args.model_type == MODEL_TYPE_MUSTC:
return MUSTC(args.dataset_path, subset="train")
else:
raise ValueError(f"Encountered unsupported model type {args.model_type}.")
......
from pathlib import Path
from typing import Union

import torch
import torchaudio
import yaml

FOLDER_IN_ARCHIVE = "en-de"
SAMPLE_RATE = 16000


class MUSTC(torch.utils.data.Dataset):
    def __init__(
        self,
        root: Union[str, Path],
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        language: str = "en",
        subset: str = "train",
    ):
        root = Path(root)
        data_dir = root / folder_in_archive / "data" / subset
        wav_dir = data_dir / "wav"
        yaml_path = data_dir / "txt" / f"{subset}.yaml"
        trans_path = data_dir / "txt" / f"{subset}.{language}"
        with open(yaml_path, "r") as stream:
            file_list = yaml.safe_load(stream)
        with open(trans_path, "r") as f:
            self.trans_list = f.readlines()
        assert len(file_list) == len(self.trans_list)
        self.idx_target_lengths = []
        self.wav_list = []
        for idx, item in enumerate(file_list):
            offset = int(item["offset"] * SAMPLE_RATE)
            duration = int(item["duration"] * SAMPLE_RATE)
            self.idx_target_lengths.append((idx, item["duration"]))
            file_path = wav_dir / item["wav"]
            self.wav_list.append((file_path, offset, duration))

    def _get_mustc_item(self, idx):
        file_path, offset, duration = self.wav_list[idx]
        waveform, sr = torchaudio.load(file_path, frame_offset=offset, num_frames=duration)
        assert sr == SAMPLE_RATE
        transcript = self.trans_list[idx].replace("\n", "")
        return (waveform, transcript)

    def __getitem__(self, idx):
        return self._get_mustc_item(idx)

    def __len__(self):
        return len(self.wav_list)
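A minimal usage sketch for the dataset above; the root path is an assumption and should contain the `en-de` folder:
```
from mustc.dataset import MUSTC  # module path assumed from the other imports in this commit

# Subsets used in this recipe: "train", "dev", "tst-COMMON", "tst-HE".
dataset = MUSTC("./datasets/mustc", subset="dev")  # assumed path
waveform, transcript = dataset[0]  # 16 kHz waveform tensor and its cased, punctuated transcript
print(len(dataset), waveform.shape, transcript)
```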
from functools import partial
from typing import List

import sentencepiece as spm
import torch
import torchaudio
from common import (
    Batch,
    FunctionalModule,
    GlobalStatsNormalization,
    WarmupLR,
    batch_by_token_count,
    piecewise_linear_log,
    post_process_hypos,
    spectrogram_transform,
)
from pytorch_lightning import LightningModule
from torchaudio.models import RNNTBeamSearch, emformer_rnnt_base

from .dataset import MUSTC


class CustomDataset(torch.utils.data.Dataset):
    r"""Sort samples by target length and batch to max token count."""

    def __init__(self, base_dataset, max_token_limit, max_len):
        super().__init__()
        self.base_dataset = base_dataset
        idx_target_lengths = self.base_dataset.idx_target_lengths
        idx_target_lengths = [ele for ele in idx_target_lengths if ele[1] <= max_len]
        idx_target_lengths = sorted(idx_target_lengths, key=lambda x: x[1])
        self.batches = batch_by_token_count(idx_target_lengths, max_token_limit)

    def __getitem__(self, idx):
        return [self.base_dataset[subidx] for subidx in self.batches[idx]]

    def __len__(self):
        return len(self.batches)


class MuSTCRNNTModule(LightningModule):
    def __init__(
        self,
        *,
        mustc_path: str,
        sp_model_path: str,
        global_stats_path: str,
    ):
        super().__init__()

        self.model = emformer_rnnt_base(num_symbols=501)
        self.loss = torchaudio.transforms.RNNTLoss(reduction="mean", clamp=1.0)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
        self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)

        self.train_data_pipeline = torch.nn.Sequential(
            FunctionalModule(piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        )
        self.valid_data_pipeline = torch.nn.Sequential(
            FunctionalModule(piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        )

        self.mustc_path = mustc_path
        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        self.blank_idx = self.sp_model.get_piece_size()

    def _extract_labels(self, samples: List):
        """Convert text transcript into int labels."""
        targets = [self.sp_model.encode(sample[1]) for sample in samples]
        lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
        targets = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(elem) for elem in targets],
            batch_first=True,
            padding_value=1.0,
        ).to(dtype=torch.int32)
        return targets, lengths

    def _train_extract_features(self, samples: List):
        mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.train_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _valid_extract_features(self, samples: List):
        mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.valid_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _train_collate_fn(self, samples: List):
        features, feature_lengths = self._train_extract_features(samples)
        targets, target_lengths = self._extract_labels(samples)
        return Batch(features, feature_lengths, targets, target_lengths)

    def _valid_collate_fn(self, samples: List):
        features, feature_lengths = self._valid_extract_features(samples)
        targets, target_lengths = self._extract_labels(samples)
        return Batch(features, feature_lengths, targets, target_lengths)

    def _test_collate_fn(self, samples: List):
        return self._valid_collate_fn(samples), [sample[1] for sample in samples]

    def _step(self, batch, batch_idx, step_type):
        if batch is None:
            return None

        prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
        prepended_targets[:, 1:] = batch.targets
        prepended_targets[:, 0] = self.blank_idx
        prepended_target_lengths = batch.target_lengths + 1
        output, src_lengths, _, _ = self.model(
            batch.features,
            batch.feature_lengths,
            prepended_targets,
            prepended_target_lengths,
        )
        loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
        self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return (
            [self.optimizer],
            [
                {"scheduler": self.warmup_lr_scheduler, "interval": "step"},
            ],
        )

    def forward(self, batch: Batch):
        decoder = RNNTBeamSearch(self.model, self.blank_idx)
        hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
        return post_process_hypos(hypotheses, self.sp_model)[0][0]

    def training_step(self, batch: Batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

    def test_step(self, batch_tuple, batch_idx):
        return self._step(batch_tuple[0], batch_idx, "test")

    def train_dataloader(self):
        dataset = CustomDataset(MUSTC(self.mustc_path, subset="train"), 100, 20)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self._train_collate_fn,
            num_workers=10,
            shuffle=True,
        )
        return dataloader

    def val_dataloader(self):
        dataset = CustomDataset(MUSTC(self.mustc_path, subset="dev"), 100, 20)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self._valid_collate_fn,
            num_workers=10,
        )
        return dataloader

    def test_common_dataloader(self):
        dataset = MUSTC(self.mustc_path, subset="tst-COMMON")
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
        return dataloader

    def test_he_dataloader(self):
        dataset = MUSTC(self.mustc_path, subset="tst-HE")
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
        return dataloader

    def dev_dataloader(self):
        dataset = MUSTC(self.mustc_path, subset="dev")
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
        return dataloader
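A minimal training sketch for the module above, assuming a PyTorch Lightning version contemporary with this recipe (argument names such as `gpus` may differ in later Lightning releases); the paths and checkpoint settings are illustrative assumptions:
```
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from mustc.lightning import MuSTCRNNTModule  # module path assumed

model = MuSTCRNNTModule(
    mustc_path="./datasets/mustc",
    sp_model_path="./spm_bpe_500.model",
    global_stats_path="./global_stats.json",
)
trainer = Trainer(
    default_root_dir="./experiments",
    max_epochs=120,
    num_nodes=1,
    gpus=8,
    strategy="ddp",
    gradient_clip_val=5.0,
    # Monitors the validation loss logged by the module's _step method.
    callbacks=[ModelCheckpoint(dirpath="./experiments/checkpoints", monitor="Losses/val_loss", save_top_k=5)],
)
trainer.fit(model)  # dataloaders come from the module's train_dataloader/val_dataloader
```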
"""Train the SentencePiece model by using the transcripts of MuST-C release v2.0 training set.
Example:
python train_spm.py --mustc-path /home/datasets/
"""
import io
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
logger = logging.getLogger(__name__)
def _parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--mustc-path",
required=True,
type=pathlib.Path,
help="Path to MUST-C dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path("./spm_bpe_500.model"),
type=pathlib.Path,
help="File to save model to. (Default: './spm_bpe_500.model')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def train_spm(input):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=500,
model_type="bpe",
input_sentence_size=-1,
character_coverage=1.0,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def cli_main():
args = _parse_args()
_init_logger(args.debug)
with open(args.mustc_path / "en-de/data/train/txt/train.en") as f:
lines = [line.replace("\n", "") for line in f]
model = train_spm(lines)
with open(args.output_file, "wb") as f:
f.write(model)
logger.info("Successfully trained the sentencepiece model")
if __name__ == "__main__":
cli_main()
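Once trained, the serialized model can be loaded back with SentencePiece, matching how the lightning module consumes it (the path is an assumption):
```
import sentencepiece as spm

sp_model = spm.SentencePieceProcessor(model_file="./spm_bpe_500.model")
print(sp_model.get_piece_size())  # 500
ids = sp_model.encode("MuST-C transcripts keep their casing and punctuation.")  # list of subword ids
```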
@@ -2,7 +2,7 @@
Example:
python train_spm.py --tedlium-path /home/datasets/
"""
import io
import logging
import os
import pathlib
@@ -22,10 +22,10 @@ def _parse_args():
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--output-dir",
default=pathlib.Path("./"),
"--output-file",
default=pathlib.Path("./spm_bpe_500.model"),
type=pathlib.Path,
help="File to save feature statistics to. (Default: './')",
help="File to save model to. (Default: './spm_bpe_500.model')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
@@ -43,30 +43,19 @@ def _extract_train_text(tedlium_path, output_dir):
if transcript == "ignore_time_segment_in_scoring\n":
continue
else:
transcript = transcript.lower().replace("<unk>", "<garbage>")
transcript = transcript.replace("<unk>", "<garbage>").replace("\n", "")
transcripts.append(transcript)
with open(output_dir / "text_train.txt", "w") as f:
f.writelines(transcripts)
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
return transcripts
def cli_main():
args = _parse_args()
_init_logger(args.debug)
_extract_train_text(args.tedlium_path, args.output_dir)
def train_spm(input):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
input=args.output_dir / "text_train.txt",
sentence_iterator=iter(input),
vocab_size=500,
model_prefix="spm_bpe_500",
model_type="bpe",
input_sentence_size=100000000,
input_sentence_size=-1,
character_coverage=1.0,
user_defined_symbols=["<garbage>"],
bos_id=0,
@@ -74,6 +63,24 @@ def cli_main():
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
transcripts = _extract_train_text(args.tedlium_path, args.output_dir)
model = train_spm(transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
logger.info("Successfully trained the sentencepiece model")
......
@@ -2,8 +2,9 @@ import logging
import pathlib
from argparse import ArgumentParser
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC
from librispeech.lightning import LibriSpeechRNNTModule
from mustc.lightning import MuSTCRNNTModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tedlium3.lightning import TEDLIUM3RNNTModule
@@ -56,40 +57,48 @@ def get_lightning_module(args):
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
elif args.model_type == MODEL_TYPE_MUSTC:
return MuSTCRNNTModule(
mustc_path=str(args.dataset_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
)
else:
raise ValueError(f"Encountered unsupported model type {args.model_type}.")
def parse_args():
parser = ArgumentParser()
parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
parser.add_argument(
"--global_stats_path",
"--model-type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC], required=True
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
required=True,
)
parser.add_argument(
"--dataset_path",
"--dataset-path",
type=pathlib.Path,
help="Path to datasets.",
required=True,
)
parser.add_argument(
"--sp_model_path",
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--exp_dir",
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--num_nodes",
"--num-nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
@@ -107,7 +116,7 @@ def parse_args():
help="Number of epochs to train for. (Default: 120)",
)
parser.add_argument(
"--gradient_clip_val", default=10.0, type=float, help="Value to clip gradient values to. (Default: 10.0)"
"--gradient-clip-val", default=10.0, type=float, help="Value to clip gradient values to. (Default: 10.0)"
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
......