"""Train a Conformer RNN-T model on LibriSpeech with PyTorch Lightning."""
import pathlib
from argparse import ArgumentParser

from lightning import ConformerRNNTModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.plugins import DDPPlugin


def run_train(args):
    """Build the Lightning trainer and fit a ``ConformerRNNTModule``.

    Args:
        args: Namespace produced by ``cli_main``'s parser. Reads ``exp_dir``,
            ``epochs``, ``nodes``, ``gpus``, ``librispeech_path``,
            ``sp_model_path`` and ``global_stats_path``.
    """
    checkpoint_dir = args.exp_dir / "checkpoints"

    # Two checkpoint callbacks write into the same directory: one keeps the
    # 5 lowest-validation-loss checkpoints, the other the 5 lowest-training-loss
    # checkpoints. Full state (not just weights) is saved so training can resume.
    # NOTE(review): assumes ConformerRNNTModule logs metrics under exactly
    # "Losses/val_loss" and "Losses/train_loss" — confirm against the module.
    checkpoint = ModelCheckpoint(
        checkpoint_dir,
        monitor="Losses/val_loss",
        mode="min",
        save_top_k=5,
        save_weights_only=False,
        verbose=True,
    )
    train_checkpoint = ModelCheckpoint(
        checkpoint_dir,
        monitor="Losses/train_loss",
        mode="min",
        save_top_k=5,
        save_weights_only=False,
        verbose=True,
    )
    # Record the learning rate every optimizer step (not once per epoch).
    lr_monitor = LearningRateMonitor(logging_interval="step")
    callbacks = [
        checkpoint,
        train_checkpoint,
        lr_monitor,
    ]

    trainer = Trainer(
        default_root_dir=args.exp_dir,
        max_epochs=args.epochs,
        num_nodes=args.nodes,
        gpus=args.gpus,
        accelerator="gpu",
        # Multi-node / multi-GPU data-parallel training; unused-parameter
        # detection is disabled (it adds overhead on every backward pass).
        strategy=DDPPlugin(find_unused_parameters=False),
        callbacks=callbacks,
        # Dataloaders are rebuilt at the start of every epoch.
        # NOTE(review): presumably so the module can re-bucket/reshuffle its
        # data per epoch — confirm intent against ConformerRNNTModule.
        reload_dataloaders_every_n_epochs=1,
    )

    # The module receives plain string paths rather than pathlib.Path objects.
    model = ConformerRNNTModule(
        librispeech_path=str(args.librispeech_path),
        sp_model_path=str(args.sp_model_path),
        global_stats_path=str(args.global_stats_path),
    )
    trainer.fit(model)


def cli_main():
    """Parse command-line arguments and launch training via ``run_train``."""
    parser = ArgumentParser()
    parser.add_argument(
        "--exp-dir",
        default=pathlib.Path("./exp"),
        type=pathlib.Path,
        help="Directory to save checkpoints and logs to. (Default: './exp')",
    )
    parser.add_argument(
        "--global-stats-path",
        default=pathlib.Path("global_stats.json"),
        type=pathlib.Path,
        help="Path to JSON file containing feature means and stddevs.",
    )
    # The two dataset/tokenizer paths have no sensible default. Without
    # required=True an omitted flag yields None, which run_train would turn
    # into the literal path "None" and fail only much later, deep inside
    # data or tokenizer loading.
    parser.add_argument(
        "--librispeech-path",
        type=pathlib.Path,
        required=True,
        help="Path to LibriSpeech datasets.",
    )
    parser.add_argument(
        "--sp-model-path",
        type=pathlib.Path,
        required=True,
        help="Path to SentencePiece model.",
    )
    parser.add_argument(
        "--nodes",
        default=4,
        type=int,
        help="Number of nodes to use for training. (Default: 4)",
    )
    parser.add_argument(
        "--gpus",
        default=8,
        type=int,
        help="Number of GPUs per node to use for training. (Default: 8)",
    )
    parser.add_argument(
        "--epochs",
        default=120,
        type=int,
        help="Number of epochs to train for. (Default: 120)",
    )
    args = parser.parse_args()

    run_train(args)


if __name__ == "__main__":
    cli_main()