Commit a7d1b31c authored by Jeff Hwang's avatar Jeff Hwang Committed by Facebook GitHub Bot
Browse files

Revise LibriSpeech Conformer RNN-T recipe (#2535)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2535

Modifies LibriSpeech Conformer RNN-T example recipe to make the Lightning module and datamodule more generic and reusable.

Reviewed By: mthrok

Differential Revision: D36731576

fbshipit-source-id: 4643e86fac78f3c2bacc15f5d385bc7b10f410a2
parent 54eb0991
......@@ -4,7 +4,8 @@ from argparse import ArgumentParser
import torch
import torchaudio
from lightning import ConformerRNNTModule, get_data_module
from lightning import ConformerRNNTModule
from transforms import get_data_module
logger = logging.getLogger()
......
import logging
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from data_module import LibriSpeechDataModule
from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch
from torchaudio.prototype.models import conformer_rnnt_base
from transforms import Batch, TestTransform, TrainTransform, ValTransform
logger = logging.getLogger()
_expected_spm_vocab_size = 1023
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
r"""Learning rate scheduler that performs linear warmup and exponential annealing.
......@@ -74,10 +76,10 @@ def post_process_hypos(
class ConformerRNNTModule(LightningModule):
def __init__(self, sp_model_path):
def __init__(self, sp_model):
super().__init__()
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
......@@ -169,15 +171,3 @@ class ConformerRNNTModule(LightningModule):
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
def get_data_module(librispeech_path, global_stats_path, sp_model_path):
    """Build a ``LibriSpeechDataModule`` wired with train/val/test transforms.

    Args:
        librispeech_path: root directory of the LibriSpeech dataset.
        global_stats_path: path to the JSON file of global feature statistics.
        sp_model_path: path to the SentencePiece model file.

    Returns:
        A ``LibriSpeechDataModule`` whose three splits each use a transform
        built from the same stats file and SentencePiece model.
    """
    # All three transforms are constructed from the same two resources.
    shared = {"global_stats_path": global_stats_path, "sp_model_path": sp_model_path}
    return LibriSpeechDataModule(
        librispeech_path=librispeech_path,
        train_transform=TrainTransform(**shared),
        val_transform=ValTransform(**shared),
        test_transform=TestTransform(**shared),
    )
import pathlib
from argparse import ArgumentParser
from lightning import ConformerRNNTModule, get_data_module
from lightning import ConformerRNNTModule
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins import DDPPlugin
from transforms import get_data_module
def run_train(args):
......
import json
import math
from collections import namedtuple
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
from data_module import LibriSpeechDataModule
from lightning import Batch
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
......@@ -107,3 +105,15 @@ class TestTransform:
def __call__(self, sample):
return self.val_transforms([sample]), [sample]
def get_data_module(librispeech_path, global_stats_path, sp_model_path):
    """Build a ``LibriSpeechDataModule`` wired with train/val/test transforms.

    Args:
        librispeech_path: root directory of the LibriSpeech dataset.
        global_stats_path: path to the JSON file of global feature statistics.
        sp_model_path: path to the SentencePiece model file.

    Returns:
        A ``LibriSpeechDataModule`` whose three splits each use a transform
        built from the same stats file and SentencePiece model.
    """
    # All three transforms are constructed from the same two resources.
    shared = {"global_stats_path": global_stats_path, "sp_model_path": sp_model_path}
    return LibriSpeechDataModule(
        librispeech_path=librispeech_path,
        train_transform=TrainTransform(**shared),
        val_transform=ValTransform(**shared),
        test_transform=TestTransform(**shared),
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment