Commit b5d77b15 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add unit tests for PyTorch Lightning modules of emformer_rnnt recipes (#2240)

Summary:
- Refactor the existing unit test for `LibriSpeechRNNTModule`.
- Add unit tests for `TEDLIUM3RNNTModule` and `MuSTCRNNTModule`.
- Replace the lambdas in `TEDLIUM3RNNTModule` with `functools.partial` so that the module passes the Lightning unit test (see the sketch below).
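
The motivation for the lambda-to-`partial` swap, in brief: `functools.partial` objects built from module-level functions can be pickled, while lambdas cannot, so anything that serializes or deep-copies the data pipeline fails when a lambda is embedded in it. Below is a minimal, self-contained sketch of that difference; it is illustrative only and not the recipe's code.

import pickle
from functools import partial

import torch

# A lambda cannot be pickled, so embedding it in an nn.Module breaks
# anything that needs to serialize or deep-copy that module.
transpose_lambda = lambda x: x.transpose(1, 2)  # noqa: E731
try:
    pickle.dumps(transpose_lambda)
except (pickle.PicklingError, AttributeError, TypeError) as err:
    print(f"lambda is not picklable: {err}")

# The equivalent partial over torch.transpose pickles fine and
# produces the same result.
transpose_partial = partial(torch.transpose, dim0=1, dim1=2)
pickle.dumps(transpose_partial)

x = torch.rand(2, 80, 100)
assert torch.equal(transpose_lambda(x), transpose_partial(x))

This is why the recipe now wraps `partial(torch.transpose, dim0=1, dim1=2)` and `partial(torch.nn.functional.pad, pad=(0, 4))` in `FunctionalModule` instead of the old lambdas.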

Pull Request resolved: https://github.com/pytorch/audio/pull/2240

Reviewed By: mthrok

Differential Revision: D34285195

Pulled By: nateanl

fbshipit-source-id: 4f20749c85ddd25cbb0eafc1733c64212542338f
parent c5c4bbfd
import os
+from functools import partial
from typing import List

import sentencepiece as spm
@@ -86,20 +87,20 @@ class TEDLIUM3RNNTModule(LightningModule):
        self.train_data_pipeline = torch.nn.Sequential(
            FunctionalModule(piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            torchaudio.transforms.TimeMasking(100, p=0.2),
-            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        )
        self.valid_data_pipeline = torch.nn.Sequential(
            FunctionalModule(piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
-            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
+            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        )

        self.tedlium_path = tedlium_path
@@ -197,8 +198,8 @@ class TEDLIUM3RNNTModule(LightningModule):
    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

-    def test_step(self, batch, batch_idx):
-        return self._step(batch, batch_idx, "test")
+    def test_step(self, batch_tuple, batch_idx):
+        return self._step(batch_tuple[0], batch_idx, "test")

    def train_dataloader(self):
        dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="train"), 100)
...
from contextlib import contextmanager
+from functools import partial
from unittest.mock import patch

import torch
+from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule

+from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader

if is_module_available("pytorch_lightning", "sentencepiece"):
    from asr.emformer_rnnt.librispeech.lightning import LibriSpeechRNNTModule


-class MockSentencePieceProcessor:
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def get_piece_size(self):
-        return 4096
-
-    def encode(self, input):
-        return [1, 5, 2]
-
-    def decode(self, input):
-        return "hey"
-
-    def unk_id(self):
-        return 0
-
-    def eos_id(self):
-        return 1
-
-    def pad_id(self):
-        return 2
-
-
class MockLIBRISPEECH:
    def __init__(self, *args, **kwargs):
        pass
@@ -50,23 +31,16 @@ class MockLIBRISPEECH:
        return 10


-class MockCustomDataset:
-    def __init__(self, base_dataset, *args, **kwargs):
-        self.base_dataset = base_dataset
-
-    def __getitem__(self, n: int):
-        return [self.base_dataset[n]]
-
-    def __len__(self):
-        return len(self.base_dataset)
-
-
@contextmanager
def get_lightning_module():
-    with patch("sentencepiece.SentencePieceProcessor", new=MockSentencePieceProcessor), patch(
-        "asr.emformer_rnnt.librispeech.lightning.GlobalStatsNormalization", new=torch.nn.Identity
-    ), patch("torchaudio.datasets.LIBRISPEECH", new=MockLIBRISPEECH), patch(
-        "asr.emformer_rnnt.librispeech.lightning.CustomDataset", new=MockCustomDataset
-    ):
+    with patch(
+        "sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=4096)
+    ), patch("asr.emformer_rnnt.librispeech.lightning.GlobalStatsNormalization", new=torch.nn.Identity), patch(
+        "torchaudio.datasets.LIBRISPEECH", new=MockLIBRISPEECH
+    ), patch(
+        "asr.emformer_rnnt.librispeech.lightning.CustomDataset", new=MockCustomDataset
+    ), patch(
+        "torch.utils.data.DataLoader", new=MockDataloader
+    ):
        yield LibriSpeechRNNTModule(
            librispeech_path="librispeech_path",
@@ -80,28 +54,29 @@ def get_lightning_module():
class TestLibriSpeechRNNTModule(TorchaudioTestCase):
    @classmethod
    def setUpClass(cls) -> None:
+        super().setUpClass()
        torch.random.manual_seed(31)

-    def test_training_step(self):
-        with get_lightning_module() as lightning_module:
-            train_dataloader = lightning_module.train_dataloader()
-            batch = next(iter(train_dataloader))
-            lightning_module.training_step(batch, 0)
-
-    def test_validation_step(self):
-        with get_lightning_module() as lightning_module:
-            val_dataloader = lightning_module.val_dataloader()
-            batch = next(iter(val_dataloader))
-            lightning_module.validation_step(batch, 0)
-
-    def test_test_step(self):
-        with get_lightning_module() as lightning_module:
-            test_dataloader = lightning_module.test_dataloader()
-            batch = next(iter(test_dataloader))
-            lightning_module.test_step(batch, 0)
-
-    def test_forward(self):
-        with get_lightning_module() as lightning_module:
-            val_dataloader = lightning_module.val_dataloader()
-            batch = next(iter(val_dataloader))
-            lightning_module(batch)
+    @parameterized.expand(
+        [
+            ("training_step", "train_dataloader"),
+            ("validation_step", "val_dataloader"),
+            ("test_step", "test_dataloader"),
+        ]
+    )
+    def test_step(self, step_fname, dataloader_fname):
+        with get_lightning_module() as lightning_module:
+            dataloader = getattr(lightning_module, dataloader_fname)()
+            batch = next(iter(dataloader))
+            getattr(lightning_module, step_fname)(batch, 0)
+
+    @parameterized.expand(
+        [
+            ("val_dataloader",),
+        ]
+    )
+    def test_forward(self, dataloader_fname):
+        with get_lightning_module() as lightning_module:
+            dataloader = getattr(lightning_module, dataloader_fname)()
+            batch = next(iter(dataloader))
+            lightning_module(batch)
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule

from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader

if is_module_available("pytorch_lightning", "sentencepiece"):
    from asr.emformer_rnnt.mustc.lightning import MuSTCRNNTModule


class MockMUSTC:
    def __init__(self, *args, **kwargs):
        pass

    def __getitem__(self, n: int):
        return (
            torch.rand(1, 32640),
            "sup",
        )

    def __len__(self):
        return 10


@contextmanager
def get_lightning_module():
    with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch(
        "asr.emformer_rnnt.mustc.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("asr.emformer_rnnt.mustc.lightning.MUSTC", new=MockMUSTC), patch(
        "asr.emformer_rnnt.mustc.lightning.CustomDataset", new=MockCustomDataset
    ), patch(
        "torch.utils.data.DataLoader", new=MockDataloader
    ):
        yield MuSTCRNNTModule(
            mustc_path="mustc_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )


@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestMuSTCRNNTModule(TorchaudioTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        torch.random.manual_seed(31)

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_common_dataloader"),
            ("test_step", "test_he_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            getattr(lightning_module, step_fname)(batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            lightning_module(batch)
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule

from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader

if is_module_available("pytorch_lightning", "sentencepiece"):
    from asr.emformer_rnnt.tedlium3.lightning import TEDLIUM3RNNTModule


class MockTEDLIUM:
    def __init__(self, *args, **kwargs):
        pass

    def __getitem__(self, n: int):
        return (
            torch.rand(1, 32640),
            16000,
            "sup",
            2,
            3,
            4,
        )

    def __len__(self):
        return 10


@contextmanager
def get_lightning_module():
    with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch(
        "asr.emformer_rnnt.tedlium3.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("torchaudio.datasets.TEDLIUM", new=MockTEDLIUM), patch(
        "asr.emformer_rnnt.tedlium3.lightning.CustomDataset", new=MockCustomDataset
    ), patch(
        "torch.utils.data.DataLoader", new=MockDataloader
    ):
        yield TEDLIUM3RNNTModule(
            tedlium_path="tedlium_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )


@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestTEDLIUM3RNNTModule(TorchaudioTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        torch.random.manual_seed(31)

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            getattr(lightning_module, step_fname)(batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            lightning_module(batch)
class MockSentencePieceProcessor:
    def __init__(self, num_symbols, *args, **kwargs):
        self.num_symbols = num_symbols

    def get_piece_size(self):
        return self.num_symbols

    def encode(self, input):
        return [1, 5, 2]

    def decode(self, input):
        return "hey"

    def unk_id(self):
        return 0

    def eos_id(self):
        return 1

    def pad_id(self):
        return 2


class MockCustomDataset:
    def __init__(self, base_dataset, *args, **kwargs):
        self.base_dataset = base_dataset

    def __getitem__(self, n: int):
        return [self.base_dataset[n]]

    def __len__(self):
        return len(self.base_dataset)


class MockDataloader:
    def __init__(self, base_dataset, batch_size, collate_fn, *args, **kwargs):
        self.base_dataset = base_dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn

    def __iter__(self):
        for sample in iter(self.base_dataset):
            if self.batch_size == 1:
                sample = [sample]
            yield self.collate_fn(sample)

    def __len__(self):
        return len(self.base_dataset)
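
For orientation, here is a hypothetical snippet (not part of this commit) that exercises the mocks above the same way the test modules do. `TinyDataset` and `collate_waveforms` are made-up names, and the relative import assumes the snippet lives next to `utils.py` in the same test package.

import torch

from .utils import MockCustomDataset, MockDataloader  # assumed: same package as utils.py above


class TinyDataset:
    """Hypothetical 3-sample dataset of (waveform, transcript) pairs."""

    def __getitem__(self, n: int):
        if n >= 3:
            raise IndexError  # ends iteration via the sequence protocol
        return (torch.rand(1, 16000), f"utterance {n}")

    def __len__(self):
        return 3


def collate_waveforms(samples):
    # MockCustomDataset wraps each item in a list, and MockDataloader wraps
    # each of those in another list when batch_size == 1.
    waveforms = [waveform for group in samples for waveform, _ in group]
    return torch.stack(waveforms)


loader = MockDataloader(MockCustomDataset(TinyDataset()), batch_size=1, collate_fn=collate_waveforms)
for batch in loader:
    assert batch.shape == (1, 1, 16000)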
@@ -388,7 +388,7 @@ EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py``
-`here <https://github.com/pytorch/audio/tree/main/examples/asr/librispeech_emformer_rnnt>`__ with default arguments.
+`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with default arguments.

Please refer to :py:class:`RNNTBundle` for usage instructions.
"""
@@ -44,7 +44,7 @@ EMFORMER_RNNT_BASE_TEDLIUM3.__doc__ = """Pre-trained Emformer-RNNT-based ASR pip
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on TED-LIUM Release 3 dataset using training script ``train.py``
-`here <https://github.com/pytorch/audio/tree/main/examples/asr/tedlium3_emformer_rnnt>`__ with ``num_symbols=501``.
+`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with ``num_symbols=501``.

Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions.
"""