Commit b0c8e239 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add Conformer RNN-T model prototype (#2322)

Summary:
Adds Conformer RNN-T model as prototype feature, by way of factory functions `conformer_rnnt_model` and `conformer_rnnt_base`, the latter of which instantiates a baseline version of the model. Also includes the following:
- Modifies `Conformer` to accept arguments `use_group_norm` and `convolution_first` to pass to each of its `ConformerLayer` instances.
- Makes `_Predictor` an abstract class and introduces `_EmformerEncoder` and `_ConformerEncoder`.
- Introduces tests for `conformer_rnnt_model`.
- Adds docs.

Pull Request resolved: https://github.com/pytorch/audio/pull/2322

Reviewed By: xiaohui-zhang

Differential Revision: D35565987

Pulled By: hwangjeff

fbshipit-source-id: cb37bb0477ae3d5fcf0b7124f334f4cbb89b5789
parent bd319959
......@@ -60,6 +60,7 @@ Prototype API References
prototype
prototype.io
prototype.ctc_decoder
prototype.models
prototype.pipelines
Getting Started
......
torchaudio.prototype.models
===========================
.. py:module:: torchaudio.prototype.models
.. currentmodule:: torchaudio.prototype.models
conformer_rnnt_model
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_rnnt_model
conformer_rnnt_base
~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_rnnt_base
......@@ -19,4 +19,5 @@ imported explicitly, e.g.
.. toctree::
prototype.io
prototype.ctc_decoder
prototype.models
prototype.pipelines
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.rnnt_test_impl import ConformerRNNTTestImpl
class ConformerRNNTFloat32CPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Runs the shared Conformer RNN-T test suite on CPU with float32 tensors."""
    dtype = torch.float32
    device = torch.device("cpu")
class ConformerRNNTFloat64CPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Runs the shared Conformer RNN-T test suite on CPU with float64 tensors."""
    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.rnnt_test_impl import ConformerRNNTTestImpl
@skipIfNoCuda
class ConformerRNNTFloat32GPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Runs the shared Conformer RNN-T test suite on CUDA with float32 tensors (skipped without a GPU)."""
    dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
class ConformerRNNTFloat64GPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Runs the shared Conformer RNN-T test suite on CUDA with float64 tensors (skipped without a GPU)."""
    dtype = torch.float64
    device = torch.device("cuda")
import torch
from torchaudio.prototype.models import conformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class ConformerRNNTTestImpl(TestBaseMixin):
    """Device/dtype-agnostic tests for `conformer_rnnt_model`.

    Subclasses (one per device/dtype combination) provide `dtype` and `device`
    via `TestBaseMixin`. Covers TorchScript consistency and output shapes for
    the model's `forward`, `transcribe`, `predict`, and `join` methods.
    """
    def _get_input_config(self):
        """Return test-input sizes derived from the model configuration."""
        model_config = self._get_model_config()
        max_input_length = 59
        return {
            "batch_size": 7,
            "max_input_length": max_input_length,
            "num_symbols": model_config["num_symbols"],
            "max_target_length": 45,
            "input_dim": model_config["input_dim"],
            "encoding_dim": model_config["encoding_dim"],
            # The transcriber downsamples time by `time_reduction_stride`,
            # so the joiner sees correspondingly shorter sequences.
            "joiner_max_input_length": max_input_length // model_config["time_reduction_stride"],
            "time_reduction_stride": model_config["time_reduction_stride"],
        }
    def _get_model_config(self):
        """Return kwargs for `conformer_rnnt_model`, deliberately small to keep tests fast."""
        return {
            "input_dim": 80,
            "num_symbols": 128,
            "encoding_dim": 64,
            "symbol_embedding_dim": 32,
            "num_lstm_layers": 2,
            "lstm_hidden_dim": 11,
            "lstm_layer_norm": True,
            "lstm_layer_norm_epsilon": 1e-5,
            "lstm_dropout": 0.3,
            "joiner_activation": "tanh",
            "time_reduction_stride": 4,
            "conformer_input_dim": 100,
            "conformer_ffn_dim": 33,
            "conformer_num_layers": 3,
            "conformer_num_heads": 4,
            "conformer_depthwise_conv_kernel_size": 31,
            "conformer_dropout": 0.1,
        }
    def _get_model(self):
        """Instantiate the model on the test device/dtype in eval mode (disables dropout for determinism)."""
        return conformer_rnnt_model(**self._get_model_config()).to(device=self.device, dtype=self.dtype).eval()
    def _get_transcriber_input(self):
        """Return random features of shape (batch, time, input_dim) plus full-length length tensor."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_input_length = input_config["max_input_length"]
        input_dim = input_config["input_dim"]
        input = torch.rand(batch_size, max_input_length, input_dim).to(device=self.device, dtype=self.dtype)
        lengths = torch.full((batch_size,), max_input_length).to(device=self.device, dtype=torch.int32)
        return input, lengths
    def _get_predictor_input(self):
        """Return random token ids of shape (batch, max_target_length) plus full-length length tensor."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        num_symbols = input_config["num_symbols"]
        max_target_length = input_config["max_target_length"]
        input = torch.randint(0, num_symbols, (batch_size, max_target_length)).to(device=self.device, dtype=torch.int32)
        lengths = torch.full((batch_size,), max_target_length).to(device=self.device, dtype=torch.int32)
        return input, lengths
    def _get_joiner_input(self):
        """Return random utterance/target encodings (dim = encoding_dim) with random valid lengths."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        input_dim = input_config["encoding_dim"]
        utterance_encodings = torch.rand(batch_size, joiner_max_input_length, input_dim).to(
            device=self.device, dtype=self.dtype
        )
        # Lengths are sampled in [0, max], so zero-length sequences are exercised too.
        utterance_lengths = torch.randint(0, joiner_max_input_length + 1, (batch_size,)).to(
            device=self.device, dtype=torch.int32
        )
        target_encodings = torch.rand(batch_size, max_target_length, input_dim).to(device=self.device, dtype=self.dtype)
        target_lengths = torch.randint(0, max_target_length + 1, (batch_size,)).to(
            device=self.device, dtype=torch.int32
        )
        return utterance_encodings, utterance_lengths, target_encodings, target_lengths
    def setUp(self):
        super().setUp()
        # Fixed seed so random inputs (and any random init) are reproducible across runs.
        torch.random.manual_seed(31)
    def test_torchscript_consistency_forward(self):
        r"""Verify that scripting RNNT does not change the behavior of method `forward`."""
        inputs, input_lengths = self._get_transcriber_input()
        targets, target_lengths = self._get_predictor_input()
        rnnt = self._get_model()
        scripted = torch_script(rnnt).eval()
        ref_state, scripted_state = None, None
        # Two iterations: first with state=None, then with the state returned
        # by the previous call, to check state round-tripping under scripting.
        for _ in range(2):
            ref_out, ref_input_lengths, ref_target_lengths, ref_state = rnnt(
                inputs, input_lengths, targets, target_lengths, ref_state
            )
            (
                scripted_out,
                scripted_input_lengths,
                scripted_target_lengths,
                scripted_state,
            ) = scripted(inputs, input_lengths, targets, target_lengths, scripted_state)
            self.assertEqual(ref_out, scripted_out)
            self.assertEqual(ref_input_lengths, scripted_input_lengths)
            self.assertEqual(ref_target_lengths, scripted_target_lengths)
            self.assertEqual(ref_state, scripted_state)
    def test_torchscript_consistency_transcribe(self):
        r"""Verify that scripting RNNT does not change the behavior of method `transcribe`."""
        input, lengths = self._get_transcriber_input()
        rnnt = self._get_model()
        scripted = torch_script(rnnt)
        ref_out, ref_lengths = rnnt.transcribe(input, lengths)
        scripted_out, scripted_lengths = scripted.transcribe(input, lengths)
        self.assertEqual(ref_out, scripted_out)
        self.assertEqual(ref_lengths, scripted_lengths)
    def test_torchscript_consistency_predict(self):
        r"""Verify that scripting RNNT does not change the behavior of method `predict`."""
        input, lengths = self._get_predictor_input()
        rnnt = self._get_model()
        scripted = torch_script(rnnt)
        ref_state, scripted_state = None, None
        # Second iteration reuses the returned state to cover the stateful path.
        for _ in range(2):
            ref_out, ref_lengths, ref_state = rnnt.predict(input, lengths, ref_state)
            scripted_out, scripted_lengths, scripted_state = scripted.predict(input, lengths, scripted_state)
            self.assertEqual(ref_out, scripted_out)
            self.assertEqual(ref_lengths, scripted_lengths)
            self.assertEqual(ref_state, scripted_state)
    def test_torchscript_consistency_join(self):
        r"""Verify that scripting RNNT does not change the behavior of method `join`."""
        (
            utterance_encodings,
            utterance_lengths,
            target_encodings,
            target_lengths,
        ) = self._get_joiner_input()
        rnnt = self._get_model()
        scripted = torch_script(rnnt)
        ref_out, ref_src_lengths, ref_tgt_lengths = rnnt.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )
        scripted_out, scripted_src_lengths, scripted_tgt_lengths = scripted.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )
        self.assertEqual(ref_out, scripted_out)
        self.assertEqual(ref_src_lengths, scripted_src_lengths)
        self.assertEqual(ref_tgt_lengths, scripted_tgt_lengths)
    def test_output_shape_forward(self):
        r"""Check that method `forward` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        num_symbols = input_config["num_symbols"]
        inputs, input_lengths = self._get_transcriber_input()
        targets, target_lengths = self._get_predictor_input()
        rnnt = self._get_model()
        state = None
        # NOTE: `target_lengths` is rebound to the model's returned lengths each
        # iteration and fed back in on the second pass, along with `state`.
        for _ in range(2):
            out, out_lengths, target_lengths, state = rnnt(inputs, input_lengths, targets, target_lengths, state)
            self.assertEqual(
                (batch_size, joiner_max_input_length, max_target_length, num_symbols),
                out.shape,
            )
            self.assertEqual((batch_size,), out_lengths.shape)
            self.assertEqual((batch_size,), target_lengths.shape)
    def test_output_shape_transcribe(self):
        r"""Check that method `transcribe` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_input_length = input_config["max_input_length"]
        input, lengths = self._get_transcriber_input()
        model_config = self._get_model_config()
        encoding_dim = model_config["encoding_dim"]
        time_reduction_stride = model_config["time_reduction_stride"]
        rnnt = self._get_model()
        out, out_lengths = rnnt.transcribe(input, lengths)
        # Time axis is downsampled by the stride; feature axis is the encoding dim.
        self.assertEqual(
            (batch_size, max_input_length // time_reduction_stride, encoding_dim),
            out.shape,
        )
        self.assertEqual((batch_size,), out_lengths.shape)
    def test_output_shape_predict(self):
        r"""Check that method `predict` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_target_length = input_config["max_target_length"]
        model_config = self._get_model_config()
        encoding_dim = model_config["encoding_dim"]
        input, lengths = self._get_predictor_input()
        rnnt = self._get_model()
        state = None
        for _ in range(2):
            out, out_lengths, state = rnnt.predict(input, lengths, state)
            self.assertEqual((batch_size, max_target_length, encoding_dim), out.shape)
            self.assertEqual((batch_size,), out_lengths.shape)
    def test_output_shape_join(self):
        r"""Check that method `join` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        num_symbols = input_config["num_symbols"]
        (
            utterance_encodings,
            utterance_lengths,
            target_encodings,
            target_lengths,
        ) = self._get_joiner_input()
        rnnt = self._get_model()
        out, src_lengths, tgt_lengths = rnnt.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )
        self.assertEqual(
            (batch_size, joiner_max_input_length, max_target_length, num_symbols),
            out.shape,
        )
        self.assertEqual((batch_size,), src_lengths.shape)
        self.assertEqual((batch_size,), tgt_lengths.shape)
......@@ -223,6 +223,10 @@ class Conformer(torch.nn.Module):
num_layers (int): number of Conformer layers to instantiate.
depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
dropout (float, optional): dropout probability. (Default: 0.0)
use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
in the convolution module. (Default: ``False``)
convolution_first (bool, optional): apply the convolution module ahead of
the attention module. (Default: ``False``)
Examples:
>>> conformer = Conformer(
......@@ -245,6 +249,8 @@ class Conformer(torch.nn.Module):
num_layers: int,
depthwise_conv_kernel_size: int,
dropout: float = 0.0,
use_group_norm: bool = False,
convolution_first: bool = False,
):
super().__init__()
......@@ -255,7 +261,9 @@ class Conformer(torch.nn.Module):
ffn_dim,
num_heads,
depthwise_conv_kernel_size,
dropout,
dropout=dropout,
use_group_norm=use_group_norm,
convolution_first=convolution_first,
)
for _ in range(num_layers)
]
......
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
......@@ -133,8 +134,23 @@ class _CustomLSTM(torch.nn.Module):
return output, state
class _Transcriber(torch.nn.Module):
r"""Recurrent neural network transducer (RNN-T) transcription network.
class _Transcriber(ABC):
    r"""Interface for RNN-T transcription networks (encoders).

    Concrete encoders must implement `forward` (batch transcription) and
    `infer` (streaming inference with optional carried state).
    """
    @abstractmethod
    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pass
    @abstractmethod
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        pass
class _EmformerEncoder(torch.nn.Module, _Transcriber):
r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).
Args:
input_dim (int): feature dimension of each input sequence element.
......@@ -285,6 +301,7 @@ class _Predictor(torch.nn.Module):
output_dim (int): feature dimension of each output sequence element.
symbol_embedding_dim (int): dimension of each target token embedding.
num_lstm_layers (int): number of LSTM layers to instantiate.
lstm_hidden_dim (int): output dimension of each LSTM layer.
lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
for LSTM layers. (Default: ``False``)
lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
......@@ -299,6 +316,7 @@ class _Predictor(torch.nn.Module):
output_dim: int,
symbol_embedding_dim: int,
num_lstm_layers: int,
lstm_hidden_dim: int,
lstm_layer_norm: bool = False,
lstm_layer_norm_epsilon: float = 1e-5,
lstm_dropout: float = 0.0,
......@@ -309,8 +327,8 @@ class _Predictor(torch.nn.Module):
self.lstm_layers = torch.nn.ModuleList(
[
_CustomLSTM(
symbol_embedding_dim,
symbol_embedding_dim,
symbol_embedding_dim if idx == 0 else lstm_hidden_dim,
lstm_hidden_dim,
layer_norm=lstm_layer_norm,
layer_norm_epsilon=lstm_layer_norm_epsilon,
)
......@@ -318,7 +336,7 @@ class _Predictor(torch.nn.Module):
]
)
self.dropout = torch.nn.Dropout(p=lstm_dropout)
self.linear = torch.nn.Linear(symbol_embedding_dim, output_dim)
self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
self.output_layer_norm = torch.nn.LayerNorm(output_dim)
self.lstm_dropout = lstm_dropout
......@@ -377,7 +395,7 @@ class _Joiner(torch.nn.Module):
Args:
input_dim (int): source and target input dimension.
output_dim (int): output dimension.
activation (str, optional): activation function to use in the joiner
activation (str, optional): activation function to use in the joiner.
Must be one of ("relu", "tanh"). (Default: "relu")
"""
......@@ -729,7 +747,7 @@ def emformer_rnnt_model(
RNNT:
Emformer RNN-T model.
"""
transcriber = _Transcriber(
encoder = _EmformerEncoder(
input_dim=input_dim,
output_dim=encoding_dim,
segment_length=segment_length,
......@@ -751,12 +769,13 @@ def emformer_rnnt_model(
encoding_dim,
symbol_embedding_dim=symbol_embedding_dim,
num_lstm_layers=num_lstm_layers,
lstm_hidden_dim=symbol_embedding_dim,
lstm_layer_norm=lstm_layer_norm,
lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
lstm_dropout=lstm_dropout,
)
joiner = _Joiner(encoding_dim, num_symbols)
return RNNT(transcriber, predictor, joiner)
return RNNT(encoder, predictor, joiner)
def emformer_rnnt_base(num_symbols: int) -> RNNT:
......
from .rnnt import conformer_rnnt_base, conformer_rnnt_model

# Public API of this prototype models package.
__all__ = [
    "conformer_rnnt_base",
    "conformer_rnnt_model",
]
from typing import List, Optional, Tuple
import torch
from torchaudio.models import Conformer, RNNT
from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber
class _ConformerEncoder(torch.nn.Module, _Transcriber):
    r"""Conformer-based RNN-T transcription network (encoder).

    Pipeline: time reduction -> linear projection into the Conformer's input
    dimension -> Conformer stack -> linear projection to the encoding
    dimension -> layer norm.
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        time_reduction_stride: int,
        conformer_input_dim: int,
        conformer_ffn_dim: int,
        conformer_num_layers: int,
        conformer_num_heads: int,
        conformer_depthwise_conv_kernel_size: int,
        conformer_dropout: float,
    ) -> None:
        super().__init__()
        # Submodules are created in the same order as declared above so that
        # parameter initialization is deterministic for a fixed RNG seed.
        self.time_reduction = _TimeReduction(time_reduction_stride)
        # Time reduction stacks `stride` frames, multiplying the feature dim.
        self.input_linear = torch.nn.Linear(input_dim * time_reduction_stride, conformer_input_dim)
        self.conformer = Conformer(
            num_layers=conformer_num_layers,
            input_dim=conformer_input_dim,
            ffn_dim=conformer_ffn_dim,
            num_heads=conformer_num_heads,
            depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
            dropout=conformer_dropout,
            use_group_norm=True,
            convolution_first=True,
        )
        self.output_linear = torch.nn.Linear(conformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Transcribe a batch of feature sequences.

        Returns the layer-normalized encodings together with the
        time-reduced valid lengths produced by the Conformer.
        """
        reduced, reduced_lengths = self.time_reduction(input, lengths)
        projected = self.input_linear(reduced)
        encoded, out_lengths = self.conformer(projected, reduced_lengths)
        return self.layer_norm(self.output_linear(encoded)), out_lengths

    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Streaming inference is unsupported for this encoder; always raises."""
        raise RuntimeError("Conformer does not support streaming inference.")
def conformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    time_reduction_stride: int,
    conformer_input_dim: int,
    conformer_ffn_dim: int,
    conformer_num_layers: int,
    conformer_num_heads: int,
    conformer_depthwise_conv_kernel_size: int,
    conformer_dropout: float,
    num_symbols: int,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_hidden_dim: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
    joiner_activation: str,
) -> RNNT:
    r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        conformer_input_dim (int): dimension of Conformer input.
        conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
        conformer_num_layers (int): number of Conformer layers to instantiate.
        conformer_num_heads (int): number of attention heads in each Conformer layer.
        conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        conformer_dropout (float): Conformer dropout probability.
        num_symbols (int): cardinality of set of target tokens.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.
        joiner_activation (str): activation function to use in the joiner.
            Must be one of ("relu", "tanh").

    Returns:
        RNNT:
            Conformer RNN-T model.
    """
    # Transcription network: time reduction + Conformer stack (see _ConformerEncoder).
    encoder = _ConformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        time_reduction_stride=time_reduction_stride,
        conformer_input_dim=conformer_input_dim,
        conformer_ffn_dim=conformer_ffn_dim,
        conformer_num_layers=conformer_num_layers,
        conformer_num_heads=conformer_num_heads,
        conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
        conformer_dropout=conformer_dropout,
    )
    # Prediction network: embedding + LSTM stack over target tokens.
    predictor = _Predictor(
        num_symbols=num_symbols,
        output_dim=encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=lstm_hidden_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    # Joint network combining transcription and prediction encodings.
    joiner = _Joiner(encoding_dim, num_symbols, activation=joiner_activation)
    return RNNT(encoder, predictor, joiner)
def conformer_rnnt_base() -> RNNT:
    r"""Builds basic version of Conformer RNN-T model.

    Returns:
        RNNT:
            Conformer RNN-T model.
    """
    # Baseline hyperparameters for the Conformer RNN-T prototype.
    baseline_config = {
        "input_dim": 80,
        "encoding_dim": 1024,
        "time_reduction_stride": 4,
        "conformer_input_dim": 256,
        "conformer_ffn_dim": 1024,
        "conformer_num_layers": 16,
        "conformer_num_heads": 4,
        "conformer_depthwise_conv_kernel_size": 31,
        "conformer_dropout": 0.1,
        "num_symbols": 1024,
        "symbol_embedding_dim": 256,
        "num_lstm_layers": 2,
        "lstm_hidden_dim": 512,
        "lstm_layer_norm": True,
        "lstm_layer_norm_epsilon": 1e-5,
        "lstm_dropout": 0.3,
        "joiner_activation": "tanh",
    }
    return conformer_rnnt_model(**baseline_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment