Commit e6c89731 authored by Ruizhe (Ray) Huang's avatar Ruizhe (Ray) Huang Committed by Facebook GitHub Bot
Browse files

Librispeech RNNT recipe updates for pytorch lightening 2.0 (#3336)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3336

Reviewed By: mthrok

Differential Revision: D47846814

Pulled By: huangruizhe

fbshipit-source-id: dc12362bf243c52222dccadec3176e25e43dd652
parent 3f98fb96
......@@ -12,7 +12,7 @@ To build TorchAudio from source, refer to the [contributing guidelines](https://
### Install additional dependencies
```bash
pip install pytorch-lightning sentencepiece
pip install pytorch-lightning sentencepiece tensorboard
```
## Usage
......@@ -27,7 +27,7 @@ pip install pytorch-lightning sentencepiece
Sample SLURM command:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp_dir ./experiments --librispeech_path ./librispeech/ --global_stats_path ./global_stats.json --sp_model_path ./spm_unigram_1023.model --epochs 160
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp-dir ./experiments --librispeech-path ./librispeech/ --global-stats-path ./global_stats.json --sp-model-path ./spm_unigram_1023.model --epochs 160
```
### Evaluation
......@@ -36,7 +36,7 @@ srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.
Sample SLURM command:
```
srun python eval.py --checkpoint_path ./experiments/checkpoints/epoch=159.ckpt --librispeech_path ./librispeech/ --sp_model_path ./spm_unigram_1023.model --use_cuda
srun python eval.py --checkpoint-path ./experiments/checkpoints/epoch=159.ckpt --librispeech-path ./librispeech/ --sp-model-path ./spm_unigram_1023.model --use-cuda
```
The table below contains WER results for various splits.
......
......@@ -6,7 +6,7 @@ import sentencepiece as spm
from lightning import ConformerRNNTModule
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.strategies import DDPStrategy
from transforms import get_data_module
......@@ -39,9 +39,9 @@ def run_train(args):
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.nodes,
gpus=args.gpus,
devices=args.gpus,
accelerator="gpu",
strategy=DDPPlugin(find_unused_parameters=False),
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
gradient_clip_val=10.0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment