Commit bbdbd582 authored by hwangjeff, committed by Facebook GitHub Bot

Add unit tests for Emformer RNN-T LibriSpeech recipe (#2216)

Summary:
Adds unit tests for the Emformer RNN-T LibriSpeech recipe. Also updates the recipe to resolve errors caused by pickling lambda functions on Windows.

Pull Request resolved: https://github.com/pytorch/audio/pull/2216

Reviewed By: nateanl

Differential Revision: D34171480

Pulled By: hwangjeff

fbshipit-source-id: 5fcebb457051f3041766324863728411180f5e1e
parent 2b991225
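For context on the pickling issue mentioned in the summary: on Windows, multiprocessing uses the spawn start method, so anything handed to worker or spawned training processes must be picklable, and lambdas are not. A minimal standalone sketch (not part of this commit) of why the recipe swaps its lambdas for named functions and functools.partial:

import pickle
from functools import partial

import torch

# A lambda cannot be pickled by reference, so a data pipeline that captures one
# breaks as soon as it has to cross a process boundary (e.g. spawn-based workers).
transpose_lambda = lambda x: x.transpose(1, 2)
try:
    pickle.dumps(transpose_lambda)
except (pickle.PicklingError, AttributeError) as err:
    print("lambda is not picklable:", err)

# A functools.partial over a top-level callable pickles fine and behaves the same.
transpose_partial = partial(torch.transpose, dim0=1, dim1=2)
pickle.dumps(transpose_partial)

x = torch.rand(2, 3, 4)
assert torch.equal(transpose_lambda(x), transpose_partial(x))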
@@ -66,7 +66,7 @@ fi
 (
     set -x
     conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
-    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow
+    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning
 )
 # Install fairseq
 git clone https://github.com/pytorch/fairseq
......
@@ -57,7 +57,7 @@ fi
 (
     set -x
     conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
-    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow
+    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning
 )
 # Install fairseq
 git clone https://github.com/pytorch/fairseq
......
@@ -21,6 +21,7 @@ Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
 def piecewise_linear_log(x):
+    x = x * GAIN
     x[x > math.e] = torch.log(x[x > math.e])
     x[x <= math.e] = x[x <= math.e] / math.e
     return x
......
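For readers following the hunk above: with the gain now applied inside the function, piecewise_linear_log log-compresses values above e and scales values up to e linearly by 1/e, so the two pieces join continuously at 1. A small illustrative check (standalone sketch; GAIN here is a placeholder, the recipe imports it from its common module):

import math

import torch

GAIN = 1.0  # placeholder value for illustration; the recipe defines GAIN in common.py


def piecewise_linear_log(x):
    x = x * GAIN
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


print(piecewise_linear_log(torch.tensor([0.5, math.e, 10.0])))
# tensor([0.1839, 1.0000, 2.3026]): 0.5/e, the junction point, and log(10)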
@@ -11,7 +11,7 @@ from argparse import ArgumentParser, RawTextHelpFormatter
 import torch
 import torchaudio
-from common import GAIN, MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform
+from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform
 logger = logging.getLogger()
@@ -42,7 +42,7 @@ def generate_statistics(samples):
     for idx, sample in enumerate(samples):
         mel_spec = spectrogram_transform(sample[0].squeeze()).transpose(1, 0)
-        scaled_mel_spec = piecewise_linear_log(mel_spec * GAIN)
+        scaled_mel_spec = piecewise_linear_log(mel_spec)
         sum = scaled_mel_spec.sum(0)
         sq_sum = scaled_mel_spec.pow(2).sum(0)
         M = scaled_mel_spec.size(0)
......
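As the hunk above shows, the statistics script accumulates per-bin sums, squared sums, and a frame count over the log-scaled mel spectrograms. Presumably these are later reduced to the mean and inverse standard deviation consumed by GlobalStatsNormalization; a hedged sketch of that reduction (helper name and details are illustrative, not taken from the recipe):

import torch


def finalize_stats(total_sum: torch.Tensor, total_sq_sum: torch.Tensor, frame_count: int):
    """Turn per-bin running sums into a mean and inverse stddev (illustrative only)."""
    mean = total_sum / frame_count
    # Var[x] = E[x^2] - E[x]^2; clamp to guard against tiny negative rounding error.
    var = (total_sq_sum / frame_count - mean.pow(2)).clamp(min=1e-10)
    invstddev = 1.0 / var.sqrt()
    return mean, invstddev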
 import os
+from functools import partial
 from typing import List
 import sentencepiece as spm
 import torch
 import torchaudio
 from common import (
-    GAIN,
     Batch,
     FunctionalModule,
     GlobalStatsNormalization,
@@ -77,22 +77,22 @@ class LibriSpeechRNNTModule(LightningModule):
         self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
         self.train_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
+            FunctionalModule(piecewise_linear_log),
             GlobalStatsNormalization(global_stats_path),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
             torchaudio.transforms.FrequencyMasking(27),
             torchaudio.transforms.FrequencyMasking(27),
             torchaudio.transforms.TimeMasking(100, p=0.2),
             torchaudio.transforms.TimeMasking(100, p=0.2),
-            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
         )
         self.valid_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
+            FunctionalModule(piecewise_linear_log),
             GlobalStatsNormalization(global_stats_path),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
-            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
+            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
         )
         self.librispeech_path = librispeech_path
@@ -174,8 +174,8 @@ class LibriSpeechRNNTModule(LightningModule):
     def validation_step(self, batch, batch_idx):
         return self._step(batch, batch_idx, "val")

-    def test_step(self, batch, batch_idx):
-        return self._step(batch, batch_idx, "test")
+    def test_step(self, batch_tuple, batch_idx):
+        return self._step(batch_tuple[0], batch_idx, "test")

     def train_dataloader(self):
         dataset = torch.utils.data.ConcatDataset(
......
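The partial-based modules introduced above are intended as drop-in replacements for the lambdas they displace; the transpose case was already checked in the earlier pickling sketch, and the padding case can be verified the same way. A quick standalone check (shape chosen arbitrarily for illustration):

from functools import partial

import torch

x = torch.rand(3, 80, 100)  # e.g. (batch, freq, time)

# Right-pad the last (time) dimension by 4 frames: lambda vs. keyword partial.
pad_lambda = lambda t: torch.nn.functional.pad(t, (0, 4))
pad_partial = partial(torch.nn.functional.pad, pad=(0, 4))
assert torch.equal(pad_lambda(x), pad_partial(x))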
@@ -5,7 +5,6 @@ import sentencepiece as spm
 import torch
 import torchaudio
 from common import (
-    GAIN,
     Batch,
     FunctionalModule,
     GlobalStatsNormalization,
@@ -85,7 +84,7 @@ class TEDLIUM3RNNTModule(LightningModule):
         self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
         self.train_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
+            FunctionalModule(piecewise_linear_log),
             GlobalStatsNormalization(global_stats_path),
             FunctionalModule(lambda x: x.transpose(1, 2)),
             torchaudio.transforms.FrequencyMasking(27),
@@ -96,7 +95,7 @@ class TEDLIUM3RNNTModule(LightningModule):
             FunctionalModule(lambda x: x.transpose(1, 2)),
         )
         self.valid_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
+            FunctionalModule(piecewise_linear_log),
             GlobalStatsNormalization(global_stats_path),
             FunctionalModule(lambda x: x.transpose(1, 2)),
             FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
......
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "examples", "asr", "emformer_rnnt"))

from contextlib import contextmanager
from unittest.mock import patch

import torch
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule

if is_module_available("pytorch_lightning", "sentencepiece"):
    from asr.emformer_rnnt.librispeech.lightning import LibriSpeechRNNTModule


class MockSentencePieceProcessor:
    def __init__(self, *args, **kwargs):
        pass

    def get_piece_size(self):
        return 4096

    def encode(self, input):
        return [1, 5, 2]

    def decode(self, input):
        return "hey"

    def unk_id(self):
        return 0

    def eos_id(self):
        return 1

    def pad_id(self):
        return 2


class MockLIBRISPEECH:
    def __init__(self, *args, **kwargs):
        pass

    def __getitem__(self, n: int):
        return (
            torch.rand(1, 32640),
            16000,
            "sup",
            2,
            3,
            4,
        )

    def __len__(self):
        return 10


class MockCustomDataset:
    def __init__(self, base_dataset, *args, **kwargs):
        self.base_dataset = base_dataset

    def __getitem__(self, n: int):
        return [self.base_dataset[n]]

    def __len__(self):
        return len(self.base_dataset)


@contextmanager
def get_lightning_module():
    with patch("sentencepiece.SentencePieceProcessor", new=MockSentencePieceProcessor), patch(
        "asr.emformer_rnnt.librispeech.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("torchaudio.datasets.LIBRISPEECH", new=MockLIBRISPEECH), patch(
        "asr.emformer_rnnt.librispeech.lightning.CustomDataset", new=MockCustomDataset
    ):
        yield LibriSpeechRNNTModule(
            librispeech_path="librispeech_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )


@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestLibriSpeechRNNTModule(TorchaudioTestCase):
    def test_training_step(self):
        with get_lightning_module() as lightning_module:
            train_dataloader = lightning_module.train_dataloader()
            batch = next(iter(train_dataloader))
            lightning_module.training_step(batch, 0)

    def test_validation_step(self):
        with get_lightning_module() as lightning_module:
            val_dataloader = lightning_module.val_dataloader()
            batch = next(iter(val_dataloader))
            lightning_module.validation_step(batch, 0)

    def test_test_step(self):
        with get_lightning_module() as lightning_module:
            test_dataloader = lightning_module.test_dataloader()
            batch = next(iter(test_dataloader))
            lightning_module.test_step(batch, 0)

    def test_forward(self):
        with get_lightning_module() as lightning_module:
            val_dataloader = lightning_module.val_dataloader()
            batch = next(iter(val_dataloader))
            lightning_module(batch)
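One detail worth noting in the tests above: MockCustomDataset wraps each underlying sample in a one-element list, mirroring the recipe's CustomDataset, which is presumably why test_step in the diff now receives a batch_tuple and unwraps batch_tuple[0]. A tiny illustration that reuses the mock classes defined above:

# Assumes MockLIBRISPEECH and MockCustomDataset from the test file above.
dataset = MockCustomDataset(MockLIBRISPEECH())
sample = dataset[0]
assert isinstance(sample, list) and len(sample) == 1

waveform, sample_rate, transcript, *_ = sample[0]
assert waveform.shape == (1, 32640)
assert sample_rate == 16000
assert transcript == "sup"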