You need to sign in or sign up before continuing.
Commit b5d77b15 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add unit tests for PyTorch Lightning modules of emformer_rnnt recipes (#2240)

Summary:
- Refactor the current `LibriSpeechRNNTModule`'s unit test.
- Add unit tests for `TEDLIUM3RNNTModule` and `MuSTCRNNTModule`
- Replace the lambda with partial in `TEDLIUM3RNNTModule` to pass the lightning unit test.

Pull Request resolved: https://github.com/pytorch/audio/pull/2240

Reviewed By: mthrok

Differential Revision: D34285195

Pulled By: nateanl

fbshipit-source-id: 4f20749c85ddd25cbb0eafc1733c64212542338f
parent c5c4bbfd
import os import os
from functools import partial
from typing import List from typing import List
import sentencepiece as spm import sentencepiece as spm
...@@ -86,20 +87,20 @@ class TEDLIUM3RNNTModule(LightningModule): ...@@ -86,20 +87,20 @@ class TEDLIUM3RNNTModule(LightningModule):
self.train_data_pipeline = torch.nn.Sequential( self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(piecewise_linear_log), FunctionalModule(piecewise_linear_log),
GlobalStatsNormalization(global_stats_path), GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)), FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27), torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2), torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2), torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))), FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)), FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
) )
self.valid_data_pipeline = torch.nn.Sequential( self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(piecewise_linear_log), FunctionalModule(piecewise_linear_log),
GlobalStatsNormalization(global_stats_path), GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)), FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))), FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)), FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
) )
self.tedlium_path = tedlium_path self.tedlium_path = tedlium_path
...@@ -197,8 +198,8 @@ class TEDLIUM3RNNTModule(LightningModule): ...@@ -197,8 +198,8 @@ class TEDLIUM3RNNTModule(LightningModule):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val") return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx): def test_step(self, batch_tuple, batch_idx):
return self._step(batch, batch_idx, "test") return self._step(batch_tuple[0], batch_idx, "test")
def train_dataloader(self): def train_dataloader(self):
dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="train"), 100) dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="train"), 100)
......
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial
from unittest.mock import patch from unittest.mock import patch
import torch import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule
from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader
if is_module_available("pytorch_lightning", "sentencepiece"): if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.librispeech.lightning import LibriSpeechRNNTModule from asr.emformer_rnnt.librispeech.lightning import LibriSpeechRNNTModule
class MockSentencePieceProcessor:
def __init__(self, *args, **kwargs):
pass
def get_piece_size(self):
return 4096
def encode(self, input):
return [1, 5, 2]
def decode(self, input):
return "hey"
def unk_id(self):
return 0
def eos_id(self):
return 1
def pad_id(self):
return 2
class MockLIBRISPEECH: class MockLIBRISPEECH:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
...@@ -50,23 +31,16 @@ class MockLIBRISPEECH: ...@@ -50,23 +31,16 @@ class MockLIBRISPEECH:
return 10 return 10
class MockCustomDataset:
def __init__(self, base_dataset, *args, **kwargs):
self.base_dataset = base_dataset
def __getitem__(self, n: int):
return [self.base_dataset[n]]
def __len__(self):
return len(self.base_dataset)
@contextmanager @contextmanager
def get_lightning_module(): def get_lightning_module():
with patch("sentencepiece.SentencePieceProcessor", new=MockSentencePieceProcessor), patch( with patch(
"asr.emformer_rnnt.librispeech.lightning.GlobalStatsNormalization", new=torch.nn.Identity "sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=4096)
), patch("torchaudio.datasets.LIBRISPEECH", new=MockLIBRISPEECH), patch( ), patch("asr.emformer_rnnt.librispeech.lightning.GlobalStatsNormalization", new=torch.nn.Identity), patch(
"torchaudio.datasets.LIBRISPEECH", new=MockLIBRISPEECH
), patch(
"asr.emformer_rnnt.librispeech.lightning.CustomDataset", new=MockCustomDataset "asr.emformer_rnnt.librispeech.lightning.CustomDataset", new=MockCustomDataset
), patch(
"torch.utils.data.DataLoader", new=MockDataloader
): ):
yield LibriSpeechRNNTModule( yield LibriSpeechRNNTModule(
librispeech_path="librispeech_path", librispeech_path="librispeech_path",
...@@ -80,28 +54,29 @@ def get_lightning_module(): ...@@ -80,28 +54,29 @@ def get_lightning_module():
class TestLibriSpeechRNNTModule(TorchaudioTestCase): class TestLibriSpeechRNNTModule(TorchaudioTestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
super().setUpClass()
torch.random.manual_seed(31) torch.random.manual_seed(31)
def test_training_step(self): @parameterized.expand(
[
("training_step", "train_dataloader"),
("validation_step", "val_dataloader"),
("test_step", "test_dataloader"),
]
)
def test_step(self, step_fname, dataloader_fname):
with get_lightning_module() as lightning_module: with get_lightning_module() as lightning_module:
train_dataloader = lightning_module.train_dataloader() dataloader = getattr(lightning_module, dataloader_fname)()
batch = next(iter(train_dataloader)) batch = next(iter(dataloader))
lightning_module.training_step(batch, 0) getattr(lightning_module, step_fname)(batch, 0)
def test_validation_step(self): @parameterized.expand(
with get_lightning_module() as lightning_module: [
val_dataloader = lightning_module.val_dataloader() ("val_dataloader",),
batch = next(iter(val_dataloader)) ]
lightning_module.validation_step(batch, 0) )
def test_forward(self, dataloader_fname):
def test_test_step(self):
with get_lightning_module() as lightning_module:
test_dataloader = lightning_module.test_dataloader()
batch = next(iter(test_dataloader))
lightning_module.test_step(batch, 0)
def test_forward(self):
with get_lightning_module() as lightning_module: with get_lightning_module() as lightning_module:
val_dataloader = lightning_module.val_dataloader() dataloader = getattr(lightning_module, dataloader_fname)()
batch = next(iter(val_dataloader)) batch = next(iter(dataloader))
lightning_module(batch) lightning_module(batch)
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule
from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader
if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.mustc.lightning import MuSTCRNNTModule
class MockMUSTC:
    """Stand-in for the MuST-C dataset: 10 fixed-shape random samples.

    Each item mimics the real dataset's (waveform, transcript) layout so the
    lightning module's data pipeline can run without the corpus on disk.
    """

    def __init__(self, *args, **kwargs):
        # Accept and discard the real dataset's constructor arguments.
        pass

    def __getitem__(self, n: int):
        waveform = torch.rand(1, 32640)
        transcript = "sup"
        return (waveform, transcript)

    def __len__(self):
        return 10
@contextmanager
def get_lightning_module():
    """Yield a MuSTCRNNTModule whose external dependencies are mocked out.

    Patches the sentencepiece processor, feature normalization, dataset,
    dataset wrapper, and dataloader so the module can be constructed and
    exercised without model files, the MuST-C corpus, or worker processes.
    """
    mock_sp = partial(MockSentencePieceProcessor, num_symbols=500)
    with patch("sentencepiece.SentencePieceProcessor", new=mock_sp), patch(
        "asr.emformer_rnnt.mustc.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("asr.emformer_rnnt.mustc.lightning.MUSTC", new=MockMUSTC), patch(
        "asr.emformer_rnnt.mustc.lightning.CustomDataset", new=MockCustomDataset
    ), patch("torch.utils.data.DataLoader", new=MockDataloader):
        module = MuSTCRNNTModule(
            mustc_path="mustc_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )
        yield module
@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestMuSTCRNNTModule(TorchaudioTestCase):
    """Smoke tests for MuSTCRNNTModule's step methods and forward pass."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        # Fixed seed keeps the mock datasets' random waveforms reproducible.
        torch.random.manual_seed(31)

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_common_dataloader"),
            ("test_step", "test_he_dataloader"),
        ]
    )
    def test_step(self, step_name, loader_name):
        # Feed a single batch through the given step method; any wiring or
        # shape error in the module surfaces as an exception here.
        with get_lightning_module() as module:
            loader = getattr(module, loader_name)()
            first_batch = next(iter(loader))
            getattr(module, step_name)(first_batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, loader_name):
        # The module's __call__ (inference path) should accept a raw batch.
        with get_lightning_module() as module:
            loader = getattr(module, loader_name)()
            first_batch = next(iter(loader))
            module(first_batch)
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule
from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader
if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.tedlium3.lightning import TEDLIUM3RNNTModule
class MockTEDLIUM:
    """Stand-in for torchaudio.datasets.TEDLIUM: 10 fixed random samples.

    Each item mirrors the real dataset's 6-tuple layout — presumably
    (waveform, sample_rate, transcript, talk_id, speaker_id, identifier);
    only the first three fields are meaningful to the data pipeline here.
    """

    def __init__(self, *args, **kwargs):
        # Accept and discard the real dataset's constructor arguments.
        pass

    def __getitem__(self, n: int):
        waveform = torch.rand(1, 32640)
        return (waveform, 16000, "sup", 2, 3, 4)

    def __len__(self):
        return 10
@contextmanager
def get_lightning_module():
    """Yield a TEDLIUM3RNNTModule whose external dependencies are mocked out.

    Patches the sentencepiece processor, feature normalization, the global
    torchaudio TEDLIUM dataset, the dataset wrapper, and the dataloader so
    the module runs without model files, the corpus, or worker processes.
    """
    mock_sp = partial(MockSentencePieceProcessor, num_symbols=500)
    with patch("sentencepiece.SentencePieceProcessor", new=mock_sp), patch(
        "asr.emformer_rnnt.tedlium3.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("torchaudio.datasets.TEDLIUM", new=MockTEDLIUM), patch(
        "asr.emformer_rnnt.tedlium3.lightning.CustomDataset", new=MockCustomDataset
    ), patch("torch.utils.data.DataLoader", new=MockDataloader):
        module = TEDLIUM3RNNTModule(
            tedlium_path="tedlium_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )
        yield module
@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestTEDLIUM3RNNTModule(TorchaudioTestCase):
    """Smoke tests for TEDLIUM3RNNTModule's step methods and forward pass."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        # Fixed seed keeps the mock datasets' random waveforms reproducible.
        torch.random.manual_seed(31)

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_dataloader"),
        ]
    )
    def test_step(self, step_name, loader_name):
        # Feed a single batch through the given step method; any wiring or
        # shape error in the module surfaces as an exception here.
        with get_lightning_module() as module:
            loader = getattr(module, loader_name)()
            first_batch = next(iter(loader))
            getattr(module, step_name)(first_batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, loader_name):
        # The module's __call__ (inference path) should accept a raw batch.
        with get_lightning_module() as module:
            loader = getattr(module, loader_name)()
            first_batch = next(iter(loader))
            module(first_batch)
class MockSentencePieceProcessor:
    """Minimal stand-in for sentencepiece.SentencePieceProcessor.

    Reports a configurable vocabulary size and returns canned token ids /
    text, so modules can tokenize without a trained sentencepiece model.
    """

    # Fixed special-token ids returned by unk_id / eos_id / pad_id.
    _UNK_ID, _EOS_ID, _PAD_ID = 0, 1, 2

    def __init__(self, num_symbols, *args, **kwargs):
        # Vocabulary size reported by get_piece_size(); extra args ignored.
        self.num_symbols = num_symbols

    def get_piece_size(self):
        return self.num_symbols

    def encode(self, input):
        # Fixed token sequence regardless of the input text.
        return [1, 5, 2]

    def decode(self, input):
        # Fixed transcript regardless of the input ids.
        return "hey"

    def unk_id(self):
        return self._UNK_ID

    def eos_id(self):
        return self._EOS_ID

    def pad_id(self):
        return self._PAD_ID
class MockCustomDataset:
    """Stand-in for CustomDataset: wraps each sample in a one-element list.

    Mirrors the real wrapper's item layout (a list of samples) while simply
    delegating indexing and length to the wrapped dataset.
    """

    def __init__(self, base_dataset, *args, **kwargs):
        # Extra constructor arguments (e.g. max token limits) are ignored.
        self.base_dataset = base_dataset

    def __getitem__(self, n: int):
        sample = self.base_dataset[n]
        return [sample]

    def __len__(self):
        return len(self.base_dataset)
class MockDataloader:
    """Stand-in for torch.utils.data.DataLoader without worker processes.

    Iterates the wrapped dataset directly and applies ``collate_fn`` to each
    sample; when ``batch_size == 1`` the sample is first wrapped in a list to
    match the real DataLoader's batching convention.
    """

    def __init__(self, base_dataset, batch_size, collate_fn, *args, **kwargs):
        # Remaining DataLoader options (num_workers, shuffle, ...) are ignored.
        self.base_dataset = base_dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn

    def __iter__(self):
        wrap_single = self.batch_size == 1
        for sample in self.base_dataset:
            batch = [sample] if wrap_single else sample
            yield self.collate_fn(batch)

    def __len__(self):
        return len(self.base_dataset)
...@@ -388,7 +388,7 @@ EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR ...@@ -388,7 +388,7 @@ EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base` The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py`` and utilizes weights trained on LibriSpeech using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/librispeech_emformer_rnnt>`__ with default arguments. `here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with default arguments.
Please refer to :py:class:`RNNTBundle` for usage instructions. Please refer to :py:class:`RNNTBundle` for usage instructions.
""" """
...@@ -44,7 +44,7 @@ EMFORMER_RNNT_BASE_TEDLIUM3.__doc__ = """Pre-trained Emformer-RNNT-based ASR pip ...@@ -44,7 +44,7 @@ EMFORMER_RNNT_BASE_TEDLIUM3.__doc__ = """Pre-trained Emformer-RNNT-based ASR pip
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base` The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on TED-LIUM Release 3 dataset using training script ``train.py`` and utilizes weights trained on TED-LIUM Release 3 dataset using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/tedlium3_emformer_rnnt>`__ with ``num_symbols=501``. `here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with ``num_symbols=501``.
Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions. Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions.
""" """
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment