Unverified Commit 4d251485 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

[release 0.13] Remove prototype (#2749)

parent 84d8ced9
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import skipIfNoModule, TorchaudioTestCase
from .utils import MockCustomDataset, MockDataloader, MockSentencePieceProcessor
if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.mustc.lightning import MuSTCRNNTModule
class MockMUSTC:
    """Stand-in for the MuST-C dataset: ten identical (waveform, transcript) samples."""

    def __init__(self, *args, **kwargs):
        # Accept and ignore all dataset-construction arguments.
        pass

    def __getitem__(self, n: int):
        # Fixed-size random waveform paired with a dummy transcript string.
        waveform = torch.rand(1, 32640)
        transcript = "sup"
        return (waveform, transcript)

    def __len__(self):
        return 10
@contextmanager
def get_lightning_module():
    """Yield a MuSTCRNNTModule with all external dependencies patched out.

    SentencePiece, global-stats normalization, the MuST-C dataset, the custom
    dataset wrapper, and the DataLoader are all replaced with lightweight mocks.
    """
    sp_factory = partial(MockSentencePieceProcessor, num_symbols=500)
    with patch("sentencepiece.SentencePieceProcessor", new=sp_factory):
        with patch("asr.emformer_rnnt.mustc.lightning.GlobalStatsNormalization", new=torch.nn.Identity):
            with patch("asr.emformer_rnnt.mustc.lightning.MUSTC", new=MockMUSTC):
                with patch("asr.emformer_rnnt.mustc.lightning.CustomDataset", new=MockCustomDataset):
                    with patch("torch.utils.data.DataLoader", new=MockDataloader):
                        yield MuSTCRNNTModule(
                            mustc_path="mustc_path",
                            sp_model_path="sp_model_path",
                            global_stats_path="global_stats_path",
                        )
@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestMuSTCRNNTModule(TorchaudioTestCase):
    """Smoke tests for MuSTCRNNTModule: step methods and the forward pass run on mocked data."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_common_dataloader"),
            ("test_step", "test_he_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        # Execute one step (batch index 0) against the first batch of the named dataloader.
        with get_lightning_module() as module:
            loader = getattr(module, dataloader_fname)()
            first_batch = next(iter(loader))
            step_method = getattr(module, step_fname)
            step_method(first_batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        # Run the module's forward pass on the first validation batch.
        with get_lightning_module() as module:
            loader = getattr(module, dataloader_fname)()
            first_batch = next(iter(loader))
            module(first_batch)
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import skipIfNoModule, TorchaudioTestCase
from .utils import MockCustomDataset, MockDataloader, MockSentencePieceProcessor
if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.tedlium3.lightning import TEDLIUM3RNNTModule
class MockTEDLIUM:
    """Stand-in for the TEDLIUM dataset: ten identical six-field samples."""

    def __init__(self, *args, **kwargs):
        # Accept and ignore all dataset-construction arguments.
        pass

    def __getitem__(self, n: int):
        # Waveform, sample rate, transcript, then three integer metadata fields
        # (presumably talk/speaker identifiers -- not inspected by these tests).
        waveform = torch.rand(1, 32640)
        return (waveform, 16000, "sup", 2, 3, 4)

    def __len__(self):
        return 10
@contextmanager
def get_lightning_module():
    """Yield a TEDLIUM3RNNTModule with all external dependencies patched out.

    SentencePiece, global-stats normalization, the TEDLIUM dataset, the custom
    dataset wrapper, and the DataLoader are all replaced with lightweight mocks.
    """
    sp_factory = partial(MockSentencePieceProcessor, num_symbols=500)
    with patch("sentencepiece.SentencePieceProcessor", new=sp_factory):
        with patch("asr.emformer_rnnt.tedlium3.lightning.GlobalStatsNormalization", new=torch.nn.Identity):
            with patch("torchaudio.datasets.TEDLIUM", new=MockTEDLIUM):
                with patch("asr.emformer_rnnt.tedlium3.lightning.CustomDataset", new=MockCustomDataset):
                    with patch("torch.utils.data.DataLoader", new=MockDataloader):
                        yield TEDLIUM3RNNTModule(
                            tedlium_path="tedlium_path",
                            sp_model_path="sp_model_path",
                            global_stats_path="global_stats_path",
                        )
@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestTEDLIUM3RNNTModule(TorchaudioTestCase):
    """Smoke tests for TEDLIUM3RNNTModule: step methods and the forward pass run on mocked data."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        # Execute one step (batch index 0) against the first batch of the named dataloader.
        with get_lightning_module() as module:
            loader = getattr(module, dataloader_fname)()
            first_batch = next(iter(loader))
            step_method = getattr(module, step_fname)
            step_method(first_batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        # Run the module's forward pass on the first validation batch.
        with get_lightning_module() as module:
            loader = getattr(module, dataloader_fname)()
            first_batch = next(iter(loader))
            module(first_batch)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl
class ConvEmformerFloat32CPUTest(ConvEmformerTestImpl, PytorchTestCase):
    """ConvEmformer test suite instantiated for CPU / single precision."""

    device = torch.device("cpu")
    dtype = torch.float32
class ConvEmformerFloat64CPUTest(ConvEmformerTestImpl, PytorchTestCase):
    """ConvEmformer test suite instantiated for CPU / double precision."""

    device = torch.device("cpu")
    dtype = torch.float64
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl
@skipIfNoCuda
class ConvEmformerFloat32GPUTest(ConvEmformerTestImpl, PytorchTestCase):
    """ConvEmformer test suite instantiated for CUDA / single precision."""

    device = torch.device("cuda")
    dtype = torch.float32
@skipIfNoCuda
class ConvEmformerFloat64GPUTest(ConvEmformerTestImpl, PytorchTestCase):
    """ConvEmformer test suite instantiated for CUDA / double precision."""

    device = torch.device("cuda")
    dtype = torch.float64
import torch
from torchaudio.prototype.models.conv_emformer import ConvEmformer
from torchaudio_unittest.common_utils import TestBaseMixin
from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestMixin
class ConvEmformerTestImpl(EmformerTestMixin, TestBaseMixin):
    """ConvEmformer-specific model/input factories plugged into the shared Emformer tests."""

    def gen_model(self, input_dim, right_context_length):
        # Positional hyperparameters (8, 256, 3, 4, 12) follow ConvEmformer's
        # constructor order -- NOTE(review): verify names against its signature.
        model = ConvEmformer(
            input_dim,
            8,
            256,
            3,
            4,
            12,
            left_context_length=30,
            right_context_length=right_context_length,
            max_memory_size=1,
        )
        return model.to(device=self.device, dtype=self.dtype)

    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
        features = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
        # NOTE(review): lengths are cast to self.dtype (a float type), not an
        # integer dtype -- confirm this is what the shared tests expect.
        lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
            device=self.device, dtype=self.dtype
        )
        return features, lengths
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .autograd_test_impl import AutogradTestImpl
class TestAutogradCPUFloat64(AutogradTestImpl, PytorchTestCase):
    """Autograd test suite instantiated for CPU / double precision (gradcheck needs float64)."""

    device = torch.device("cpu")
    dtype = torch.float64
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .autograd_test_impl import AutogradTestImpl
@skipIfNoCuda
class TestAutogradCUDAFloat64(AutogradTestImpl, PytorchTestCase):
    """Autograd test suite instantiated for CUDA / double precision."""

    device = torch.device("cuda")
    dtype = torch.float64
import torch
import torchaudio.prototype.functional as F
from parameterized import parameterized
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import TestBaseMixin
class AutogradTestImpl(TestBaseMixin):
    """First- and second-order gradient checks for prototype functional ops."""

    @parameterized.expand(
        [
            (F.convolve,),
            (F.fftconvolve,),
        ]
    )
    def test_convolve(self, fn):
        batch_shape = (4, 3, 2)
        len_x, len_y = 23, 40
        x = torch.rand(*batch_shape, len_x, dtype=self.dtype, device=self.device, requires_grad=True)
        y = torch.rand(*batch_shape, len_y, dtype=self.dtype, device=self.device, requires_grad=True)
        self.assertTrue(gradcheck(fn, (x, y)))
        self.assertTrue(gradgradcheck(fn, (x, y)))

    def test_add_noise(self):
        batch_shape = (5, 2, 3)
        num_samples = 51
        waveform = torch.rand(*batch_shape, num_samples, dtype=self.dtype, device=self.device, requires_grad=True)
        noise = torch.rand(*batch_shape, num_samples, dtype=self.dtype, device=self.device, requires_grad=True)
        lengths = torch.rand(*batch_shape, dtype=self.dtype, device=self.device, requires_grad=True)
        snr = torch.rand(*batch_shape, dtype=self.dtype, device=self.device, requires_grad=True) * 10
        self.assertTrue(gradcheck(F.add_noise, (waveform, noise, lengths, snr)))
        self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, lengths, snr)))
import torch
import torchaudio.prototype.functional as F
from torchaudio_unittest.common_utils import nested_params, TorchaudioTestCase
class BatchConsistencyTest(TorchaudioTestCase):
    """Verify that batched calls agree with per-example calls for prototype ops."""

    # NOTE(review): self.dtype / self.device are assumed to come from the
    # TorchaudioTestCase base -- confirm their defaults.

    @nested_params(
        [F.convolve, F.fftconvolve],
    )
    def test_convolve(self, fn):
        batch_shape = (2, 3)
        len_x, len_y = 89, 43
        x = torch.rand(*batch_shape, len_x, dtype=self.dtype, device=self.device)
        y = torch.rand(*batch_shape, len_y, dtype=self.dtype, device=self.device)
        actual = fn(x, y)
        # Reference: convolve each (i, j) pair individually, then restack.
        rows = []
        for i in range(batch_shape[0]):
            per_row = [fn(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0)).squeeze(0) for j in range(batch_shape[1])]
            rows.append(torch.stack(per_row))
        expected = torch.stack(rows)
        self.assertEqual(expected, actual)

    def test_add_noise(self):
        batch_shape = (5, 2, 3)
        num_samples = 51
        waveform = torch.rand(*batch_shape, num_samples, dtype=self.dtype, device=self.device)
        noise = torch.rand(*batch_shape, num_samples, dtype=self.dtype, device=self.device)
        lengths = torch.rand(*batch_shape, dtype=self.dtype, device=self.device)
        snr = torch.rand(*batch_shape, dtype=self.dtype, device=self.device) * 10
        actual = F.add_noise(waveform, noise, lengths, snr)
        # Reference: apply add_noise element-by-element over the leading dims.
        expected = []
        for i in range(batch_shape[0]):
            for j in range(batch_shape[1]):
                for k in range(batch_shape[2]):
                    expected.append(F.add_noise(waveform[i][j][k], noise[i][j][k], lengths[i][j][k], snr[i][j][k]))
        self.assertEqual(torch.stack(expected), actual.reshape(-1, num_samples))
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .functional_test_impl import FunctionalTestImpl
class FunctionalFloat32CPUTest(FunctionalTestImpl, PytorchTestCase):
    """Functional test suite instantiated for CPU / single precision."""

    device = torch.device("cpu")
    dtype = torch.float32
class FunctionalFloat64CPUTest(FunctionalTestImpl, PytorchTestCase):
    """Functional test suite instantiated for CPU / double precision."""

    device = torch.device("cpu")
    dtype = torch.float64
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .functional_test_impl import FunctionalTestImpl
@skipIfNoCuda
class FunctionalFloat32CUDATest(FunctionalTestImpl, PytorchTestCase):
    """Functional test suite instantiated for CUDA / single precision."""

    device = torch.device("cuda")
    dtype = torch.float32
@skipIfNoCuda
class FunctionalFloat64CUDATest(FunctionalTestImpl, PytorchTestCase):
    """Functional test suite instantiated for CUDA / double precision."""

    device = torch.device("cuda")
    dtype = torch.float64
import numpy as np
import torch
import torchaudio.prototype.functional as F
from parameterized import parameterized
from scipy import signal
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
class FunctionalTestImpl(TestBaseMixin):
    """Numerics and input-validation tests for ``torchaudio.prototype.functional``.

    ``dtype`` and ``device`` are supplied by the concrete subclass that mixes
    this class with a PyTorch test case.
    """

    @nested_params(
        [(10, 4), (4, 3, 1, 2), (2,), ()],
        [(100, 43), (21, 45)],
    )
    def test_convolve_numerics(self, leading_dims, lengths):
        """Check that convolve returns values identical to those that SciPy produces."""
        L_x, L_y = lengths
        x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
        y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
        actual = F.convolve(x, y)
        # Collapse the leading dims so each 1-D signal pair can be fed to
        # scipy.signal.convolve individually; () leading dims mean one signal.
        num_signals = torch.tensor(leading_dims).prod() if leading_dims else 1
        x_reshaped = x.reshape((num_signals, L_x))
        y_reshaped = y.reshape((num_signals, L_y))
        expected = [
            signal.convolve(x_reshaped[i].detach().cpu().numpy(), y_reshaped[i].detach().cpu().numpy())
            for i in range(num_signals)
        ]
        expected = torch.tensor(np.array(expected))
        expected = expected.reshape(leading_dims + (-1,))
        self.assertEqual(expected, actual)

    @nested_params(
        [(10, 4), (4, 3, 1, 2), (2,), ()],
        [(100, 43), (21, 45)],
    )
    def test_fftconvolve_numerics(self, leading_dims, lengths):
        """Check that fftconvolve returns values identical to those that SciPy produces."""
        L_x, L_y = lengths
        x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
        y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
        actual = F.fftconvolve(x, y)
        # axes=-1: convolve along the last axis only; leading dims match exactly.
        expected = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1)
        expected = torch.tensor(expected)
        self.assertEqual(expected, actual)

    @nested_params(
        [F.convolve, F.fftconvolve],
        [(4, 3, 1, 2), (1,)],
        [(10, 4), (2, 2, 2)],
    )
    def test_convolve_input_leading_dim_check(self, fn, x_shape, y_shape):
        """Check that convolve properly rejects inputs with different leading dimensions."""
        x = torch.rand(*x_shape, dtype=self.dtype, device=self.device)
        y = torch.rand(*y_shape, dtype=self.dtype, device=self.device)
        with self.assertRaisesRegex(ValueError, "Leading dimensions"):
            fn(x, y)

    def test_add_noise_broadcast(self):
        """Check that add_noise produces correct outputs when broadcasting input dimensions."""
        leading_dims = (5, 2, 3)
        L = 51
        waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
        # noise / lengths / snr each carry singleton leading dims to broadcast.
        noise = torch.rand(5, 1, 1, L, dtype=self.dtype, device=self.device)
        lengths = torch.rand(5, 1, 3, dtype=self.dtype, device=self.device)
        snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10
        actual = F.add_noise(waveform, noise, lengths, snr)
        # Reference: expand everything explicitly, then compare with the broadcast result.
        noise_expanded = noise.expand(*leading_dims, L)
        snr_expanded = snr.expand(*leading_dims)
        lengths_expanded = lengths.expand(*leading_dims)
        expected = F.add_noise(waveform, noise_expanded, lengths_expanded, snr_expanded)
        self.assertEqual(expected, actual)

    @parameterized.expand(
        [((5, 2, 3), (2, 1, 1), (5, 2), (5, 2, 3)), ((2, 1), (5,), (5,), (5,)), ((3,), (5, 2, 3), (2, 1, 1), (5, 2))]
    )
    def test_add_noise_leading_dim_check(self, waveform_dims, noise_dims, lengths_dims, snr_dims):
        """Check that add_noise properly rejects inputs with different leading dimension lengths."""
        L = 51
        waveform = torch.rand(*waveform_dims, L, dtype=self.dtype, device=self.device)
        noise = torch.rand(*noise_dims, L, dtype=self.dtype, device=self.device)
        lengths = torch.rand(*lengths_dims, dtype=self.dtype, device=self.device)
        snr = torch.rand(*snr_dims, dtype=self.dtype, device=self.device) * 10
        with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
            F.add_noise(waveform, noise, lengths, snr)

    def test_add_noise_length_check(self):
        """Check that add_noise properly rejects inputs that have inconsistent length dimensions."""
        leading_dims = (5, 2, 3)
        L = 51
        waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
        # Noise length (50) deliberately disagrees with the waveform length (51).
        noise = torch.rand(*leading_dims, 50, dtype=self.dtype, device=self.device)
        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
        snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
        with self.assertRaisesRegex(ValueError, "Length dimensions"):
            F.add_noise(waveform, noise, lengths, snr)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_test_impl import TorchScriptConsistencyTestImpl
class TorchScriptConsistencyCPUFloat32Test(TorchScriptConsistencyTestImpl, PytorchTestCase):
    """TorchScript-consistency suite instantiated for CPU / single precision."""

    device = torch.device("cpu")
    dtype = torch.float32
class TorchScriptConsistencyCPUFloat64Test(TorchScriptConsistencyTestImpl, PytorchTestCase):
    """TorchScript-consistency suite instantiated for CPU / double precision."""

    device = torch.device("cpu")
    dtype = torch.float64
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .torchscript_consistency_test_impl import TorchScriptConsistencyTestImpl
@skipIfNoCuda
class TorchScriptConsistencyCUDAFloat32Test(TorchScriptConsistencyTestImpl, PytorchTestCase):
    """TorchScript-consistency suite instantiated for CUDA / single precision."""

    device = torch.device("cuda")
    dtype = torch.float32
@skipIfNoCuda
class TorchScriptConsistencyCUDAFloat64Test(TorchScriptConsistencyTestImpl, PytorchTestCase):
    """TorchScript-consistency suite instantiated for CUDA / double precision."""

    device = torch.device("cuda")
    dtype = torch.float64
import torch
import torchaudio.prototype.functional as F
from parameterized import parameterized
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class TorchScriptConsistencyTestImpl(TestBaseMixin):
    """Check that TorchScript-compiled prototype ops behave like their eager versions."""

    def _assert_consistency(self, func, inputs, shape_only=False):
        # Move tensor arguments onto the test device/dtype; pass others through.
        prepared = []
        for item in inputs:
            if torch.is_tensor(item):
                item = item.to(device=self.device, dtype=self.dtype)
            prepared.append(item)
        scripted = torch_script(func)
        # Reseed before each call so any randomness inside func is identical.
        torch.random.manual_seed(40)
        eager_result = func(*prepared)
        torch.random.manual_seed(40)
        scripted_result = scripted(*prepared)
        if shape_only:
            scripted_result = scripted_result.shape
            eager_result = eager_result.shape
        self.assertEqual(scripted_result, eager_result)

    @parameterized.expand(
        [
            (F.convolve,),
            (F.fftconvolve,),
        ]
    )
    def test_convolve(self, fn):
        batch_shape = (2, 3, 2)
        len_x, len_y = 32, 55
        x = torch.rand(*batch_shape, len_x, dtype=self.dtype, device=self.device)
        y = torch.rand(*batch_shape, len_y, dtype=self.dtype, device=self.device)
        self._assert_consistency(fn, (x, y))

    def test_add_noise(self):
        batch_shape = (2, 3)
        num_samples = 31
        waveform = torch.rand(*batch_shape, num_samples, dtype=self.dtype, device=self.device, requires_grad=True)
        noise = torch.rand(*batch_shape, num_samples, dtype=self.dtype, device=self.device, requires_grad=True)
        lengths = torch.rand(*batch_shape, dtype=self.dtype, device=self.device, requires_grad=True)
        snr = torch.rand(*batch_shape, dtype=self.dtype, device=self.device, requires_grad=True) * 10
        self._assert_consistency(F.add_noise, (waveform, noise, lengths, snr))
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.rnnt_test_impl import ConformerRNNTTestImpl
class ConformerRNNTFloat32CPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Conformer RNN-T test suite instantiated for CPU / single precision."""

    device = torch.device("cpu")
    dtype = torch.float32
class ConformerRNNTFloat64CPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Conformer RNN-T test suite instantiated for CPU / double precision."""

    device = torch.device("cpu")
    dtype = torch.float64
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.prototype.rnnt_test_impl import ConformerRNNTTestImpl
@skipIfNoCuda
class ConformerRNNTFloat32GPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Conformer RNN-T test suite instantiated for CUDA / single precision."""

    device = torch.device("cuda")
    dtype = torch.float32
@skipIfNoCuda
class ConformerRNNTFloat64GPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Conformer RNN-T test suite instantiated for CUDA / double precision."""

    device = torch.device("cuda")
    dtype = torch.float64
import torch
from torchaudio.prototype.models import conformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class ConformerRNNTTestImpl(TestBaseMixin):
    """TorchScript-consistency and output-shape tests for ``conformer_rnnt_model``.

    ``dtype`` and ``device`` are provided by the concrete subclass via
    ``TestBaseMixin``.
    """

    def _get_input_config(self):
        """Return sizes for synthetic inputs, derived from the model config so the two agree."""
        model_config = self._get_model_config()
        max_input_length = 59
        return {
            "batch_size": 7,
            "max_input_length": max_input_length,
            "num_symbols": model_config["num_symbols"],
            "max_target_length": 45,
            "input_dim": model_config["input_dim"],
            "encoding_dim": model_config["encoding_dim"],
            # The transcriber reduces the time dimension by the stride before the joiner.
            "joiner_max_input_length": max_input_length // model_config["time_reduction_stride"],
            "time_reduction_stride": model_config["time_reduction_stride"],
        }

    def _get_model_config(self):
        """Return deliberately small hyperparameters so model construction stays fast."""
        return {
            "input_dim": 80,
            "num_symbols": 128,
            "encoding_dim": 64,
            "symbol_embedding_dim": 32,
            "num_lstm_layers": 2,
            "lstm_hidden_dim": 11,
            "lstm_layer_norm": True,
            "lstm_layer_norm_epsilon": 1e-5,
            "lstm_dropout": 0.3,
            "joiner_activation": "tanh",
            "time_reduction_stride": 4,
            "conformer_input_dim": 100,
            "conformer_ffn_dim": 33,
            "conformer_num_layers": 3,
            "conformer_num_heads": 4,
            "conformer_depthwise_conv_kernel_size": 31,
            "conformer_dropout": 0.1,
        }

    def _get_model(self):
        """Build the model on the test device/dtype; eval mode for deterministic comparisons."""
        return conformer_rnnt_model(**self._get_model_config()).to(device=self.device, dtype=self.dtype).eval()

    def _get_transcriber_input(self):
        """Return (features, lengths): a (B, T, input_dim) tensor and full-length int32 lengths."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_input_length = input_config["max_input_length"]
        input_dim = input_config["input_dim"]
        input = torch.rand(batch_size, max_input_length, input_dim).to(device=self.device, dtype=self.dtype)
        lengths = torch.full((batch_size,), max_input_length).to(device=self.device, dtype=torch.int32)
        return input, lengths

    def _get_predictor_input(self):
        """Return (tokens, lengths): a (B, U) int32 tensor of symbol ids and full-length lengths."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        num_symbols = input_config["num_symbols"]
        max_target_length = input_config["max_target_length"]
        input = torch.randint(0, num_symbols, (batch_size, max_target_length)).to(device=self.device, dtype=torch.int32)
        lengths = torch.full((batch_size,), max_target_length).to(device=self.device, dtype=torch.int32)
        return input, lengths

    def _get_joiner_input(self):
        """Return (utterance_encodings, utterance_lengths, target_encodings, target_lengths)."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        input_dim = input_config["encoding_dim"]
        utterance_encodings = torch.rand(batch_size, joiner_max_input_length, input_dim).to(
            device=self.device, dtype=self.dtype
        )
        # Lengths are random (possibly zero) up to and including the maximum.
        utterance_lengths = torch.randint(0, joiner_max_input_length + 1, (batch_size,)).to(
            device=self.device, dtype=torch.int32
        )
        target_encodings = torch.rand(batch_size, max_target_length, input_dim).to(device=self.device, dtype=self.dtype)
        target_lengths = torch.randint(0, max_target_length + 1, (batch_size,)).to(
            device=self.device, dtype=torch.int32
        )
        return utterance_encodings, utterance_lengths, target_encodings, target_lengths

    def setUp(self):
        super().setUp()
        # Fixed seed so the random inputs are reproducible across runs.
        torch.random.manual_seed(31)

    def test_torchscript_consistency_forward(self):
        r"""Verify that scripting RNNT does not change the behavior of method `forward`."""
        inputs, input_lengths = self._get_transcriber_input()
        targets, target_lengths = self._get_predictor_input()
        rnnt = self._get_model()
        scripted = torch_script(rnnt).eval()
        ref_state, scripted_state = None, None
        # Two iterations: the second exercises the state returned by the first.
        for _ in range(2):
            ref_out, ref_input_lengths, ref_target_lengths, ref_state = rnnt(
                inputs, input_lengths, targets, target_lengths, ref_state
            )
            (
                scripted_out,
                scripted_input_lengths,
                scripted_target_lengths,
                scripted_state,
            ) = scripted(inputs, input_lengths, targets, target_lengths, scripted_state)
            self.assertEqual(ref_out, scripted_out)
            self.assertEqual(ref_input_lengths, scripted_input_lengths)
            self.assertEqual(ref_target_lengths, scripted_target_lengths)
            self.assertEqual(ref_state, scripted_state)

    def test_torchscript_consistency_transcribe(self):
        r"""Verify that scripting RNNT does not change the behavior of method `transcribe`."""
        input, lengths = self._get_transcriber_input()
        rnnt = self._get_model()
        scripted = torch_script(rnnt)
        ref_out, ref_lengths = rnnt.transcribe(input, lengths)
        scripted_out, scripted_lengths = scripted.transcribe(input, lengths)
        self.assertEqual(ref_out, scripted_out)
        self.assertEqual(ref_lengths, scripted_lengths)

    def test_torchscript_consistency_predict(self):
        r"""Verify that scripting RNNT does not change the behavior of method `predict`."""
        input, lengths = self._get_predictor_input()
        rnnt = self._get_model()
        scripted = torch_script(rnnt)
        ref_state, scripted_state = None, None
        # Two iterations: the second exercises the state returned by the first.
        for _ in range(2):
            ref_out, ref_lengths, ref_state = rnnt.predict(input, lengths, ref_state)
            scripted_out, scripted_lengths, scripted_state = scripted.predict(input, lengths, scripted_state)
            self.assertEqual(ref_out, scripted_out)
            self.assertEqual(ref_lengths, scripted_lengths)
            self.assertEqual(ref_state, scripted_state)

    def test_torchscript_consistency_join(self):
        r"""Verify that scripting RNNT does not change the behavior of method `join`."""
        (
            utterance_encodings,
            utterance_lengths,
            target_encodings,
            target_lengths,
        ) = self._get_joiner_input()
        rnnt = self._get_model()
        scripted = torch_script(rnnt)
        ref_out, ref_src_lengths, ref_tgt_lengths = rnnt.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )
        scripted_out, scripted_src_lengths, scripted_tgt_lengths = scripted.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )
        self.assertEqual(ref_out, scripted_out)
        self.assertEqual(ref_src_lengths, scripted_src_lengths)
        self.assertEqual(ref_tgt_lengths, scripted_tgt_lengths)

    def test_output_shape_forward(self):
        r"""Check that method `forward` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        num_symbols = input_config["num_symbols"]
        inputs, input_lengths = self._get_transcriber_input()
        targets, target_lengths = self._get_predictor_input()
        rnnt = self._get_model()
        state = None
        for _ in range(2):
            out, out_lengths, target_lengths, state = rnnt(inputs, input_lengths, targets, target_lengths, state)
            self.assertEqual(
                (batch_size, joiner_max_input_length, max_target_length, num_symbols),
                out.shape,
            )
            self.assertEqual((batch_size,), out_lengths.shape)
            self.assertEqual((batch_size,), target_lengths.shape)

    def test_output_shape_transcribe(self):
        r"""Check that method `transcribe` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_input_length = input_config["max_input_length"]
        input, lengths = self._get_transcriber_input()
        model_config = self._get_model_config()
        encoding_dim = model_config["encoding_dim"]
        time_reduction_stride = model_config["time_reduction_stride"]
        rnnt = self._get_model()
        out, out_lengths = rnnt.transcribe(input, lengths)
        self.assertEqual(
            (batch_size, max_input_length // time_reduction_stride, encoding_dim),
            out.shape,
        )
        self.assertEqual((batch_size,), out_lengths.shape)

    def test_output_shape_predict(self):
        r"""Check that method `predict` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_target_length = input_config["max_target_length"]
        model_config = self._get_model_config()
        encoding_dim = model_config["encoding_dim"]
        input, lengths = self._get_predictor_input()
        rnnt = self._get_model()
        state = None
        for _ in range(2):
            out, out_lengths, state = rnnt.predict(input, lengths, state)
            self.assertEqual((batch_size, max_target_length, encoding_dim), out.shape)
            self.assertEqual((batch_size,), out_lengths.shape)

    def test_output_shape_join(self):
        r"""Check that method `join` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        num_symbols = input_config["num_symbols"]
        (
            utterance_encodings,
            utterance_lengths,
            target_encodings,
            target_lengths,
        ) = self._get_joiner_input()
        rnnt = self._get_model()
        out, src_lengths, tgt_lengths = rnnt.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )
        self.assertEqual(
            (batch_size, joiner_max_input_length, max_target_length, num_symbols),
            out.shape,
        )
        self.assertEqual((batch_size,), src_lengths.shape)
        self.assertEqual((batch_size,), tgt_lengths.shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.