"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "2350968ee61a6f9ca6ecd24aba9db536e814a24c"
Commit bbdbd582 authored by hwangjeff, committed by Facebook GitHub Bot

Add unit tests for Emformer RNN-T LibriSpeech recipe (#2216)

Summary:
Adds unit tests for the Emformer RNN-T LibriSpeech recipe. Also updates the recipe to avoid the errors that arise on Windows when lambda functions are pickled.
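Background on the Windows issue: DataLoader workers on Windows start with the "spawn" method, which pickles the transform objects handed to them. Lambdas cannot be pickled, but functools.partial over a module-level callable can, which motivates the substitutions in the diff below. A minimal standalone illustration (not part of the recipe):

    import pickle
    from functools import partial

    import torch

    # Picklable: both the function and its bound arguments resolve by reference.
    pickle.dumps(partial(torch.transpose, dim0=1, dim1=2))

    try:
        pickle.dumps(lambda x: x.transpose(1, 2))
    except (pickle.PicklingError, AttributeError) as err:
        print(f"lambda is not picklable: {err}")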

Pull Request resolved: https://github.com/pytorch/audio/pull/2216

Reviewed By: nateanl

Differential Revision: D34171480

Pulled By: hwangjeff

fbshipit-source-id: 5fcebb457051f3041766324863728411180f5e1e
parent 2b991225
...
@@ -66,7 +66,7 @@ fi
 (
     set -x
     conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
-    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow
+    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning
 )
 # Install fairseq
 git clone https://github.com/pytorch/fairseq
...
...
@@ -57,7 +57,7 @@ fi
 (
     set -x
     conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
-    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow
+    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning
 )
 # Install fairseq
 git clone https://github.com/pytorch/fairseq
...
...
@@ -21,6 +21,7 @@ Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_l
 def piecewise_linear_log(x):
+    x = x * GAIN
     x[x > math.e] = torch.log(x[x > math.e])
     x[x <= math.e] = x[x <= math.e] / math.e
     return x
...
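With the gain folded into piecewise_linear_log, callers now pass the raw mel spectrogram. The map is linear (x / e) at or below e and logarithmic above, and the two branches meet at x = e since log(e) = 1 = e / e. A standalone check (GAIN is set to 1.0 here purely for illustration; the recipe imports it from common.py):

    import math

    import torch

    GAIN = 1.0  # illustrative placeholder; the recipe defines GAIN in common

    def piecewise_linear_log(x):
        x = x * GAIN  # allocates a new tensor, so the caller's input is not mutated
        x[x > math.e] = torch.log(x[x > math.e])
        # Note: this mask is evaluated on the already-updated tensor.
        x[x <= math.e] = x[x <= math.e] / math.e
        return x

    print(piecewise_linear_log(torch.tensor([1.0, math.e, math.e ** 3])))
    # tensor([0.3679, 1.0000, 3.0000])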
...
@@ -11,7 +11,7 @@ from argparse import ArgumentParser, RawTextHelpFormatter
 import torch
 import torchaudio
-from common import GAIN, MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform
+from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform

 logger = logging.getLogger()
@@ -42,7 +42,7 @@ def generate_statistics(samples):
     for idx, sample in enumerate(samples):
         mel_spec = spectrogram_transform(sample[0].squeeze()).transpose(1, 0)
-        scaled_mel_spec = piecewise_linear_log(mel_spec * GAIN)
+        scaled_mel_spec = piecewise_linear_log(mel_spec)
         sum = scaled_mel_spec.sum(0)
         sq_sum = scaled_mel_spec.pow(2).sum(0)
         M = scaled_mel_spec.size(0)
...
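For context, the loop above accumulates per-sample sums, squared sums, and frame counts toward the global feature statistics. A hedged sketch of how such accumulators typically reduce to normalization parameters (the helper name and epsilon below are assumptions, not the recipe's code):

    import torch

    def finalize_stats(total_sum, total_sq_sum, total_frames):
        # mean = E[x]; var = E[x^2] - E[x]^2, accumulated over all frames
        mean = total_sum / total_frames
        var = total_sq_sum / total_frames - mean.pow(2)
        invstddev = 1.0 / torch.sqrt(var + 1e-8)  # epsilon is an assumption
        return mean, invstddev

    feats = torch.rand(1000, 80)  # 1000 frames of 80-dim features
    mean, invstddev = finalize_stats(feats.sum(0), feats.pow(2).sum(0), feats.size(0))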
 import os
+from functools import partial
 from typing import List

 import sentencepiece as spm
 import torch
 import torchaudio
 from common import (
-    GAIN,
     Batch,
     FunctionalModule,
     GlobalStatsNormalization,
@@ -77,22 +77,22 @@ class LibriSpeechRNNTModule(LightningModule):
         self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)

         self.train_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
+            FunctionalModule(piecewise_linear_log),
             GlobalStatsNormalization(global_stats_path),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
             torchaudio.transforms.FrequencyMasking(27),
             torchaudio.transforms.FrequencyMasking(27),
             torchaudio.transforms.TimeMasking(100, p=0.2),
             torchaudio.transforms.TimeMasking(100, p=0.2),
-            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
         )
         self.valid_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
+            FunctionalModule(piecewise_linear_log),
             GlobalStatsNormalization(global_stats_path),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
-            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
-            FunctionalModule(lambda x: x.transpose(1, 2)),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
+            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
+            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
         )

         self.librispeech_path = librispeech_path
@@ -174,8 +174,8 @@ class LibriSpeechRNNTModule(LightningModule):
     def validation_step(self, batch, batch_idx):
         return self._step(batch, batch_idx, "val")

-    def test_step(self, batch, batch_idx):
-        return self._step(batch, batch_idx, "test")
+    def test_step(self, batch_tuple, batch_idx):
+        return self._step(batch_tuple[0], batch_idx, "test")

     def train_dataloader(self):
         dataset = torch.utils.data.ConcatDataset(
...
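The payoff of the partial-based pipeline is that the whole nn.Sequential can now be pickled, which is what spawned DataLoader workers on Windows require. A standalone check, with FunctionalModule assumed to match the thin wrapper in the recipe's common.py:

    import pickle
    from functools import partial

    import torch

    class FunctionalModule(torch.nn.Module):
        # Assumed to mirror common.FunctionalModule: wraps a callable as an nn.Module.
        def __init__(self, functional):
            super().__init__()
            self.functional = functional

        def forward(self, x):
            return self.functional(x)

    pipeline = torch.nn.Sequential(
        FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
    )
    pickle.dumps(pipeline)  # succeeds; the lambda-based version raised PicklingError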
...
@@ -5,7 +5,6 @@ import sentencepiece as spm
 import torch
 import torchaudio
 from common import (
-    GAIN,
     Batch,
     FunctionalModule,
     GlobalStatsNormalization,
@@ -85,7 +84,7 @@ class TEDLIUM3RNNTModule(LightningModule):
         self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)

         self.train_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
+            FunctionalModule(piecewise_linear_log),
             GlobalStatsNormalization(global_stats_path),
             FunctionalModule(lambda x: x.transpose(1, 2)),
             torchaudio.transforms.FrequencyMasking(27),
@@ -96,7 +95,7 @@ class TEDLIUM3RNNTModule(LightningModule):
             FunctionalModule(lambda x: x.transpose(1, 2)),
         )
         self.valid_data_pipeline = torch.nn.Sequential(
-            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
+            FunctionalModule(piecewise_linear_log),
             GlobalStatsNormalization(global_stats_path),
             FunctionalModule(lambda x: x.transpose(1, 2)),
             FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
...
import os
import sys

# Make the recipe's modules importable from the test tree.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "examples", "asr", "emformer_rnnt"))

from contextlib import contextmanager
from unittest.mock import patch

import torch
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule

if is_module_available("pytorch_lightning", "sentencepiece"):
    from asr.emformer_rnnt.librispeech.lightning import LibriSpeechRNNTModule


class MockSentencePieceProcessor:
    def __init__(self, *args, **kwargs):
        pass

    def get_piece_size(self):
        return 4096

    def encode(self, input):
        return [1, 5, 2]

    def decode(self, input):
        return "hey"

    def unk_id(self):
        return 0

    def eos_id(self):
        return 1

    def pad_id(self):
        return 2


class MockLIBRISPEECH:
    def __init__(self, *args, **kwargs):
        pass

    def __getitem__(self, n: int):
        return (
            torch.rand(1, 32640),
            16000,
            "sup",
            2,
            3,
            4,
        )

    def __len__(self):
        return 10


class MockCustomDataset:
    def __init__(self, base_dataset, *args, **kwargs):
        self.base_dataset = base_dataset

    def __getitem__(self, n: int):
        # Wraps each sample in a list, mirroring the recipe's CustomDataset;
        # this is why test_step unpacks batch_tuple[0].
        return [self.base_dataset[n]]

    def __len__(self):
        return len(self.base_dataset)


@contextmanager
def get_lightning_module():
    with patch("sentencepiece.SentencePieceProcessor", new=MockSentencePieceProcessor), patch(
        "asr.emformer_rnnt.librispeech.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("torchaudio.datasets.LIBRISPEECH", new=MockLIBRISPEECH), patch(
        "asr.emformer_rnnt.librispeech.lightning.CustomDataset", new=MockCustomDataset
    ):
        yield LibriSpeechRNNTModule(
            librispeech_path="librispeech_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )


@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestLibriSpeechRNNTModule(TorchaudioTestCase):
    def test_training_step(self):
        with get_lightning_module() as lightning_module:
            train_dataloader = lightning_module.train_dataloader()
            batch = next(iter(train_dataloader))
            lightning_module.training_step(batch, 0)

    def test_validation_step(self):
        with get_lightning_module() as lightning_module:
            val_dataloader = lightning_module.val_dataloader()
            batch = next(iter(val_dataloader))
            lightning_module.validation_step(batch, 0)

    def test_test_step(self):
        with get_lightning_module() as lightning_module:
            test_dataloader = lightning_module.test_dataloader()
            batch = next(iter(test_dataloader))
            lightning_module.test_step(batch, 0)

    def test_forward(self):
        with get_lightning_module() as lightning_module:
            val_dataloader = lightning_module.val_dataloader()
            batch = next(iter(val_dataloader))
            lightning_module(batch)
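The get_lightning_module helper above follows the standard pattern of exposing a stack of unittest.mock.patch contexts through contextlib.contextmanager, so each test gets a fully mocked module and automatic cleanup. A minimal self-contained sketch of the same pattern (the names here are illustrative, not from the recipe):

    import os
    from contextlib import contextmanager
    from unittest.mock import patch

    @contextmanager
    def fake_cwd(path):
        # Patches stay active for the caller's `with` body and are
        # unwound automatically when the generator resumes after `yield`.
        with patch("os.getcwd", return_value=path):
            yield

    with fake_cwd("/tmp/fake"):
        assert os.getcwd() == "/tmp/fake"
    assert os.getcwd() != "/tmp/fake"  # patch removed on exit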