"vscode:/vscode.git/clone" did not exist on "289aa4855449fe9162c8da7e13aaaaef81755a3b"
Commit a7d1b31c authored by Jeff Hwang, committed by Facebook GitHub Bot

Revise LibriSpeech Conformer RNN-T recipe (#2535)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2535

Modifies LibriSpeech Conformer RNN-T example recipe to make the Lightning module and datamodule more generic and reusable.
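
Roughly, the revised pieces compose as in the following sketch. This is illustrative only: the file paths and Trainer arguments are placeholders, not the recipe's actual configuration.

import sentencepiece as spm
from pytorch_lightning import Trainer

from lightning import ConformerRNNTModule
from transforms import get_data_module

# The caller now loads the SentencePiece model and injects it into the module.
sp_model = spm.SentencePieceProcessor(model_file="spm.model")  # placeholder path
model = ConformerRNNTModule(sp_model)
data_module = get_data_module(
    librispeech_path="/datasets/librispeech",  # placeholder path
    global_stats_path="global_stats.json",     # placeholder path
    sp_model_path="spm.model",                 # placeholder path
)
Trainer(max_epochs=1).fit(model, datamodule=data_module)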

Reviewed By: mthrok

Differential Revision: D36731576

fbshipit-source-id: 4643e86fac78f3c2bacc15f5d385bc7b10f410a2
parent 54eb0991
@@ -4,7 +4,8 @@ from argparse import ArgumentParser
 import torch
 import torchaudio
-from lightning import ConformerRNNTModule, get_data_module
+from lightning import ConformerRNNTModule
+from transforms import get_data_module
 
 logger = logging.getLogger()
...
 import logging
 import math
+from collections import namedtuple
 from typing import List, Tuple
 
 import sentencepiece as spm
 import torch
 import torchaudio
-from data_module import LibriSpeechDataModule
 from pytorch_lightning import LightningModule
 from torchaudio.models import Hypothesis, RNNTBeamSearch
 from torchaudio.prototype.models import conformer_rnnt_base
-from transforms import Batch, TestTransform, TrainTransform, ValTransform
 
 logger = logging.getLogger()
 
 _expected_spm_vocab_size = 1023
+Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
 
 
 class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
     r"""Learning rate scheduler that performs linear warmup and exponential annealing.
@@ -74,10 +76,10 @@ def post_process_hypos(
 class ConformerRNNTModule(LightningModule):
-    def __init__(self, sp_model_path):
+    def __init__(self, sp_model):
         super().__init__()
-        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
+        self.sp_model = sp_model
         spm_vocab_size = self.sp_model.get_piece_size()
         assert spm_vocab_size == _expected_spm_vocab_size, (
             "The model returned by conformer_rnnt_base expects a SentencePiece model of "
@@ -169,15 +171,3 @@ class ConformerRNNTModule(LightningModule):
     def test_step(self, batch, batch_idx):
         return self._step(batch, batch_idx, "test")
-
-
-def get_data_module(librispeech_path, global_stats_path, sp_model_path):
-    train_transform = TrainTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
-    val_transform = ValTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
-    test_transform = TestTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
-    return LibriSpeechDataModule(
-        librispeech_path=librispeech_path,
-        train_transform=train_transform,
-        val_transform=val_transform,
-        test_transform=test_transform,
-    )
...
 import pathlib
 from argparse import ArgumentParser
 
-from lightning import ConformerRNNTModule, get_data_module
+from lightning import ConformerRNNTModule
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from pytorch_lightning.plugins import DDPPlugin
+from transforms import get_data_module
 
 
 def run_train(args):
...
 import json
 import math
-from collections import namedtuple
 from functools import partial
 from typing import List
 
 import sentencepiece as spm
 import torch
 import torchaudio
+from data_module import LibriSpeechDataModule
+from lightning import Batch
 
-Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
 
 _decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
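
For reference, torch.iinfo(torch.int16).max is 32767, so _decibel evaluates to 2 * 20 * log10(32767) ≈ 180.6. Judging from the expression alone, this appears to be a gain constant referenced to int16 full scale; the diff itself does not explain the factor of two.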
@@ -107,3 +105,15 @@ class TestTransform:
     def __call__(self, sample):
         return self.val_transforms([sample]), [sample]
+
+
+def get_data_module(librispeech_path, global_stats_path, sp_model_path):
+    train_transform = TrainTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
+    val_transform = ValTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
+    test_transform = TestTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
+    return LibriSpeechDataModule(
+        librispeech_path=librispeech_path,
+        train_transform=train_transform,
+        val_transform=val_transform,
+        test_transform=test_transform,
+    )
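
With get_data_module relocated here, ConformerRNNTModule no longer depends on any particular datamodule, so another corpus could reuse it wholesale. A hypothetical sketch (MyDataModule is an invented name, not part of this diff; it would need to yield the same Batch tuples the transforms produce):

import sentencepiece as spm
from pytorch_lightning import Trainer

from lightning import ConformerRNNTModule
from my_dataset import MyDataModule  # hypothetical datamodule

sp_model = spm.SentencePieceProcessor(model_file="spm.model")  # placeholder path
Trainer().fit(ConformerRNNTModule(sp_model), datamodule=MyDataModule())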