Commit b0c8e239 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add Conformer RNN-T model prototype (#2322)

Summary:
Adds Conformer RNN-T model as prototype feature, by way of factory functions `conformer_rnnt_model` and `conformer_rnnt_base`, the latter of which instantiates a baseline version of the model. Also includes the following:
- Modifies `Conformer` to accept arguments `use_group_norm` and `convolution_first` to pass to each of its `ConformerLayer` instances.
- Makes `_Predictor` an abstract class and introduces `_EmformerEncoder` and `_ConformerEncoder`.
- Introduces tests for `conformer_rnnt_model`.
- Adds docs.

Pull Request resolved: https://github.com/pytorch/audio/pull/2322

Reviewed By: xiaohui-zhang

Differential Revision: D35565987

Pulled By: hwangjeff

fbshipit-source-id: cb37bb0477ae3d5fcf0b7124f334f4cbb89b5789
parent bd319959
...@@ -60,6 +60,7 @@ Prototype API References ...@@ -60,6 +60,7 @@ Prototype API References
prototype prototype
prototype.io prototype.io
prototype.ctc_decoder prototype.ctc_decoder
prototype.models
prototype.pipelines prototype.pipelines
Getting Started Getting Started
......
torchaudio.prototype.models
===========================
.. py:module:: torchaudio.prototype.models
.. currentmodule:: torchaudio.prototype.models
conformer_rnnt_model
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_rnnt_model
conformer_rnnt_base
~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_rnnt_base
...@@ -19,4 +19,5 @@ imported explicitly, e.g. ...@@ -19,4 +19,5 @@ imported explicitly, e.g.
.. toctree:: .. toctree::
prototype.io prototype.io
prototype.ctc_decoder prototype.ctc_decoder
prototype.models
prototype.pipelines prototype.pipelines
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.rnnt_test_impl import ConformerRNNTTestImpl
class ConformerRNNTFloat32CPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Runs the shared Conformer RNN-T test suite on CPU with float32 tensors."""

    dtype = torch.float32
    device = torch.device("cpu")
class ConformerRNNTFloat64CPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Runs the shared Conformer RNN-T test suite on CPU with float64 tensors."""

    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.rnnt_test_impl import ConformerRNNTTestImpl
@skipIfNoCuda
class ConformerRNNTFloat32GPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Runs the shared Conformer RNN-T test suite on CUDA with float32 tensors.

    Skipped entirely when no CUDA device is available.
    """

    dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
class ConformerRNNTFloat64GPUTest(ConformerRNNTTestImpl, PytorchTestCase):
    """Runs the shared Conformer RNN-T test suite on CUDA with float64 tensors.

    Skipped entirely when no CUDA device is available.
    """

    dtype = torch.float64
    device = torch.device("cuda")
import torch
from torchaudio.prototype.models import conformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class ConformerRNNTTestImpl(TestBaseMixin):
    """Device/dtype-parameterized test suite for ``conformer_rnnt_model``.

    Concrete subclasses provide ``dtype`` and ``device`` (via ``TestBaseMixin``)
    and run every ``test_*`` method here. The suite checks two properties:
    TorchScript consistency (scripting does not change outputs of ``forward``,
    ``transcribe``, ``predict``, ``join``) and output shapes of those methods.
    """

    def _get_input_config(self):
        """Returns sizes for synthetic test inputs, derived from the model config.

        ``joiner_max_input_length`` reflects the transcriber's output length
        after time reduction (``max_input_length // time_reduction_stride``).
        """
        model_config = self._get_model_config()
        max_input_length = 59
        return {
            "batch_size": 7,
            "max_input_length": max_input_length,
            "num_symbols": model_config["num_symbols"],
            "max_target_length": 45,
            "input_dim": model_config["input_dim"],
            "encoding_dim": model_config["encoding_dim"],
            "joiner_max_input_length": max_input_length // model_config["time_reduction_stride"],
            "time_reduction_stride": model_config["time_reduction_stride"],
        }

    def _get_model_config(self):
        """Returns kwargs for ``conformer_rnnt_model``.

        Dimensions are kept small so the model is cheap to build and run in tests.
        """
        return {
            "input_dim": 80,
            "num_symbols": 128,
            "encoding_dim": 64,
            "symbol_embedding_dim": 32,
            "num_lstm_layers": 2,
            "lstm_hidden_dim": 11,
            "lstm_layer_norm": True,
            "lstm_layer_norm_epsilon": 1e-5,
            "lstm_dropout": 0.3,
            "joiner_activation": "tanh",
            "time_reduction_stride": 4,
            "conformer_input_dim": 100,
            "conformer_ffn_dim": 33,
            "conformer_num_layers": 3,
            "conformer_num_heads": 4,
            "conformer_depthwise_conv_kernel_size": 31,
            "conformer_dropout": 0.1,
        }

    def _get_model(self):
        """Instantiates the model under test on the configured device/dtype, in eval mode."""
        return conformer_rnnt_model(**self._get_model_config()).to(device=self.device, dtype=self.dtype).eval()

    def _get_transcriber_input(self):
        """Returns random features ``(B, T, input_dim)`` and full-length ``lengths``."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_input_length = input_config["max_input_length"]
        input_dim = input_config["input_dim"]
        input = torch.rand(batch_size, max_input_length, input_dim).to(device=self.device, dtype=self.dtype)
        lengths = torch.full((batch_size,), max_input_length).to(device=self.device, dtype=torch.int32)
        return input, lengths

    def _get_predictor_input(self):
        """Returns random token ids ``(B, U)`` and full-length ``lengths``."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        num_symbols = input_config["num_symbols"]
        max_target_length = input_config["max_target_length"]
        input = torch.randint(0, num_symbols, (batch_size, max_target_length)).to(device=self.device, dtype=torch.int32)
        lengths = torch.full((batch_size,), max_target_length).to(device=self.device, dtype=torch.int32)
        return input, lengths

    def _get_joiner_input(self):
        """Returns random source/target encodings and random valid lengths for ``join``.

        Both encodings use ``encoding_dim``; lengths are sampled in
        ``[0, max_length]`` so partially-valid sequences are exercised.
        """
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        input_dim = input_config["encoding_dim"]
        utterance_encodings = torch.rand(batch_size, joiner_max_input_length, input_dim).to(
            device=self.device, dtype=self.dtype
        )
        utterance_lengths = torch.randint(0, joiner_max_input_length + 1, (batch_size,)).to(
            device=self.device, dtype=torch.int32
        )
        target_encodings = torch.rand(batch_size, max_target_length, input_dim).to(device=self.device, dtype=self.dtype)
        target_lengths = torch.randint(0, max_target_length + 1, (batch_size,)).to(
            device=self.device, dtype=torch.int32
        )
        return utterance_encodings, utterance_lengths, target_encodings, target_lengths

    def setUp(self):
        super().setUp()
        # Fixed seed so random inputs are reproducible across runs.
        torch.random.manual_seed(31)

    def test_torchscript_consistency_forward(self):
        r"""Verify that scripting RNNT does not change the behavior of method `forward`."""
        inputs, input_lengths = self._get_transcriber_input()
        targets, target_lengths = self._get_predictor_input()

        rnnt = self._get_model()
        scripted = torch_script(rnnt).eval()

        # Two iterations so the state returned by the first call is fed back in,
        # exercising both the `None`-state and populated-state paths.
        ref_state, scripted_state = None, None
        for _ in range(2):
            ref_out, ref_input_lengths, ref_target_lengths, ref_state = rnnt(
                inputs, input_lengths, targets, target_lengths, ref_state
            )
            (
                scripted_out,
                scripted_input_lengths,
                scripted_target_lengths,
                scripted_state,
            ) = scripted(inputs, input_lengths, targets, target_lengths, scripted_state)
            self.assertEqual(ref_out, scripted_out)
            self.assertEqual(ref_input_lengths, scripted_input_lengths)
            self.assertEqual(ref_target_lengths, scripted_target_lengths)
            self.assertEqual(ref_state, scripted_state)

    def test_torchscript_consistency_transcribe(self):
        r"""Verify that scripting RNNT does not change the behavior of method `transcribe`."""
        input, lengths = self._get_transcriber_input()

        rnnt = self._get_model()
        scripted = torch_script(rnnt)

        ref_out, ref_lengths = rnnt.transcribe(input, lengths)
        scripted_out, scripted_lengths = scripted.transcribe(input, lengths)

        self.assertEqual(ref_out, scripted_out)
        self.assertEqual(ref_lengths, scripted_lengths)

    def test_torchscript_consistency_predict(self):
        r"""Verify that scripting RNNT does not change the behavior of method `predict`."""
        input, lengths = self._get_predictor_input()

        rnnt = self._get_model()
        scripted = torch_script(rnnt)

        # Two iterations to cover both `None` and carried-over LSTM state.
        ref_state, scripted_state = None, None
        for _ in range(2):
            ref_out, ref_lengths, ref_state = rnnt.predict(input, lengths, ref_state)
            scripted_out, scripted_lengths, scripted_state = scripted.predict(input, lengths, scripted_state)
            self.assertEqual(ref_out, scripted_out)
            self.assertEqual(ref_lengths, scripted_lengths)
            self.assertEqual(ref_state, scripted_state)

    def test_torchscript_consistency_join(self):
        r"""Verify that scripting RNNT does not change the behavior of method `join`."""
        (
            utterance_encodings,
            utterance_lengths,
            target_encodings,
            target_lengths,
        ) = self._get_joiner_input()

        rnnt = self._get_model()
        scripted = torch_script(rnnt)

        ref_out, ref_src_lengths, ref_tgt_lengths = rnnt.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )
        scripted_out, scripted_src_lengths, scripted_tgt_lengths = scripted.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )

        self.assertEqual(ref_out, scripted_out)
        self.assertEqual(ref_src_lengths, scripted_src_lengths)
        self.assertEqual(ref_tgt_lengths, scripted_tgt_lengths)

    def test_output_shape_forward(self):
        r"""Check that method `forward` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        num_symbols = input_config["num_symbols"]

        inputs, input_lengths = self._get_transcriber_input()
        targets, target_lengths = self._get_predictor_input()

        rnnt = self._get_model()

        # NOTE: `target_lengths` is intentionally rebound each iteration with the
        # lengths returned by the model, so the second pass feeds them back in.
        state = None
        for _ in range(2):
            out, out_lengths, target_lengths, state = rnnt(inputs, input_lengths, targets, target_lengths, state)
            self.assertEqual(
                (batch_size, joiner_max_input_length, max_target_length, num_symbols),
                out.shape,
            )
            self.assertEqual((batch_size,), out_lengths.shape)
            self.assertEqual((batch_size,), target_lengths.shape)

    def test_output_shape_transcribe(self):
        r"""Check that method `transcribe` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_input_length = input_config["max_input_length"]

        input, lengths = self._get_transcriber_input()

        model_config = self._get_model_config()
        encoding_dim = model_config["encoding_dim"]
        time_reduction_stride = model_config["time_reduction_stride"]
        rnnt = self._get_model()

        out, out_lengths = rnnt.transcribe(input, lengths)
        # Time axis shrinks by the time-reduction stride.
        self.assertEqual(
            (batch_size, max_input_length // time_reduction_stride, encoding_dim),
            out.shape,
        )
        self.assertEqual((batch_size,), out_lengths.shape)

    def test_output_shape_predict(self):
        r"""Check that method `predict` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        max_target_length = input_config["max_target_length"]

        model_config = self._get_model_config()
        encoding_dim = model_config["encoding_dim"]

        input, lengths = self._get_predictor_input()
        rnnt = self._get_model()

        state = None
        for _ in range(2):
            out, out_lengths, state = rnnt.predict(input, lengths, state)
            self.assertEqual((batch_size, max_target_length, encoding_dim), out.shape)
            self.assertEqual((batch_size,), out_lengths.shape)

    def test_output_shape_join(self):
        r"""Check that method `join` produces correctly-shaped outputs."""
        input_config = self._get_input_config()
        batch_size = input_config["batch_size"]
        joiner_max_input_length = input_config["joiner_max_input_length"]
        max_target_length = input_config["max_target_length"]
        num_symbols = input_config["num_symbols"]

        (
            utterance_encodings,
            utterance_lengths,
            target_encodings,
            target_lengths,
        ) = self._get_joiner_input()

        rnnt = self._get_model()

        out, src_lengths, tgt_lengths = rnnt.join(
            utterance_encodings, utterance_lengths, target_encodings, target_lengths
        )

        self.assertEqual(
            (batch_size, joiner_max_input_length, max_target_length, num_symbols),
            out.shape,
        )
        self.assertEqual((batch_size,), src_lengths.shape)
        self.assertEqual((batch_size,), tgt_lengths.shape)
...@@ -223,6 +223,10 @@ class Conformer(torch.nn.Module): ...@@ -223,6 +223,10 @@ class Conformer(torch.nn.Module):
num_layers (int): number of Conformer layers to instantiate. num_layers (int): number of Conformer layers to instantiate.
depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer. depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
dropout (float, optional): dropout probability. (Default: 0.0) dropout (float, optional): dropout probability. (Default: 0.0)
use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
in the convolution module. (Default: ``False``)
convolution_first (bool, optional): apply the convolution module ahead of
the attention module. (Default: ``False``)
Examples: Examples:
>>> conformer = Conformer( >>> conformer = Conformer(
...@@ -245,6 +249,8 @@ class Conformer(torch.nn.Module): ...@@ -245,6 +249,8 @@ class Conformer(torch.nn.Module):
num_layers: int, num_layers: int,
depthwise_conv_kernel_size: int, depthwise_conv_kernel_size: int,
dropout: float = 0.0, dropout: float = 0.0,
use_group_norm: bool = False,
convolution_first: bool = False,
): ):
super().__init__() super().__init__()
...@@ -255,7 +261,9 @@ class Conformer(torch.nn.Module): ...@@ -255,7 +261,9 @@ class Conformer(torch.nn.Module):
ffn_dim, ffn_dim,
num_heads, num_heads,
depthwise_conv_kernel_size, depthwise_conv_kernel_size,
dropout, dropout=dropout,
use_group_norm=use_group_norm,
convolution_first=convolution_first,
) )
for _ in range(num_layers) for _ in range(num_layers)
] ]
......
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
...@@ -133,8 +134,23 @@ class _CustomLSTM(torch.nn.Module): ...@@ -133,8 +134,23 @@ class _CustomLSTM(torch.nn.Module):
return output, state return output, state
class _Transcriber(torch.nn.Module): class _Transcriber(ABC):
r"""Recurrent neural network transducer (RNN-T) transcription network. @abstractmethod
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
pass
@abstractmethod
def infer(
self,
input: torch.Tensor,
lengths: torch.Tensor,
states: Optional[List[List[torch.Tensor]]],
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
pass
class _EmformerEncoder(torch.nn.Module, _Transcriber):
r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).
Args: Args:
input_dim (int): feature dimension of each input sequence element. input_dim (int): feature dimension of each input sequence element.
...@@ -285,6 +301,7 @@ class _Predictor(torch.nn.Module): ...@@ -285,6 +301,7 @@ class _Predictor(torch.nn.Module):
output_dim (int): feature dimension of each output sequence element. output_dim (int): feature dimension of each output sequence element.
symbol_embedding_dim (int): dimension of each target token embedding. symbol_embedding_dim (int): dimension of each target token embedding.
num_lstm_layers (int): number of LSTM layers to instantiate. num_lstm_layers (int): number of LSTM layers to instantiate.
lstm_hidden_dim (int): output dimension of each LSTM layer.
lstm_layer_norm (bool, optional): if ``True``, enables layer normalization lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
for LSTM layers. (Default: ``False``) for LSTM layers. (Default: ``False``)
lstm_layer_norm_epsilon (float, optional): value of epsilon to use in lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
...@@ -299,6 +316,7 @@ class _Predictor(torch.nn.Module): ...@@ -299,6 +316,7 @@ class _Predictor(torch.nn.Module):
output_dim: int, output_dim: int,
symbol_embedding_dim: int, symbol_embedding_dim: int,
num_lstm_layers: int, num_lstm_layers: int,
lstm_hidden_dim: int,
lstm_layer_norm: bool = False, lstm_layer_norm: bool = False,
lstm_layer_norm_epsilon: float = 1e-5, lstm_layer_norm_epsilon: float = 1e-5,
lstm_dropout: float = 0.0, lstm_dropout: float = 0.0,
...@@ -309,8 +327,8 @@ class _Predictor(torch.nn.Module): ...@@ -309,8 +327,8 @@ class _Predictor(torch.nn.Module):
self.lstm_layers = torch.nn.ModuleList( self.lstm_layers = torch.nn.ModuleList(
[ [
_CustomLSTM( _CustomLSTM(
symbol_embedding_dim, symbol_embedding_dim if idx == 0 else lstm_hidden_dim,
symbol_embedding_dim, lstm_hidden_dim,
layer_norm=lstm_layer_norm, layer_norm=lstm_layer_norm,
layer_norm_epsilon=lstm_layer_norm_epsilon, layer_norm_epsilon=lstm_layer_norm_epsilon,
) )
...@@ -318,7 +336,7 @@ class _Predictor(torch.nn.Module): ...@@ -318,7 +336,7 @@ class _Predictor(torch.nn.Module):
] ]
) )
self.dropout = torch.nn.Dropout(p=lstm_dropout) self.dropout = torch.nn.Dropout(p=lstm_dropout)
self.linear = torch.nn.Linear(symbol_embedding_dim, output_dim) self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
self.output_layer_norm = torch.nn.LayerNorm(output_dim) self.output_layer_norm = torch.nn.LayerNorm(output_dim)
self.lstm_dropout = lstm_dropout self.lstm_dropout = lstm_dropout
...@@ -377,7 +395,7 @@ class _Joiner(torch.nn.Module): ...@@ -377,7 +395,7 @@ class _Joiner(torch.nn.Module):
Args: Args:
input_dim (int): source and target input dimension. input_dim (int): source and target input dimension.
output_dim (int): output dimension. output_dim (int): output dimension.
activation (str, optional): activation function to use in the joiner activation (str, optional): activation function to use in the joiner.
Must be one of ("relu", "tanh"). (Default: "relu") Must be one of ("relu", "tanh"). (Default: "relu")
""" """
...@@ -729,7 +747,7 @@ def emformer_rnnt_model( ...@@ -729,7 +747,7 @@ def emformer_rnnt_model(
RNNT: RNNT:
Emformer RNN-T model. Emformer RNN-T model.
""" """
transcriber = _Transcriber( encoder = _EmformerEncoder(
input_dim=input_dim, input_dim=input_dim,
output_dim=encoding_dim, output_dim=encoding_dim,
segment_length=segment_length, segment_length=segment_length,
...@@ -751,12 +769,13 @@ def emformer_rnnt_model( ...@@ -751,12 +769,13 @@ def emformer_rnnt_model(
encoding_dim, encoding_dim,
symbol_embedding_dim=symbol_embedding_dim, symbol_embedding_dim=symbol_embedding_dim,
num_lstm_layers=num_lstm_layers, num_lstm_layers=num_lstm_layers,
lstm_hidden_dim=symbol_embedding_dim,
lstm_layer_norm=lstm_layer_norm, lstm_layer_norm=lstm_layer_norm,
lstm_layer_norm_epsilon=lstm_layer_norm_epsilon, lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
lstm_dropout=lstm_dropout, lstm_dropout=lstm_dropout,
) )
joiner = _Joiner(encoding_dim, num_symbols) joiner = _Joiner(encoding_dim, num_symbols)
return RNNT(transcriber, predictor, joiner) return RNNT(encoder, predictor, joiner)
def emformer_rnnt_base(num_symbols: int) -> RNNT: def emformer_rnnt_base(num_symbols: int) -> RNNT:
......
from .rnnt import conformer_rnnt_base, conformer_rnnt_model

# Public API of torchaudio.prototype.models.
__all__ = [
    "conformer_rnnt_base",
    "conformer_rnnt_model",
]
from typing import List, Optional, Tuple
import torch
from torchaudio.models import Conformer, RNNT
from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber
class _ConformerEncoder(torch.nn.Module, _Transcriber):
    r"""Conformer-based RNN-T encoder (transcription network).

    Pipeline: time reduction -> linear projection into the Conformer's input
    dimension -> Conformer stack -> linear projection to ``output_dim`` ->
    layer normalization. Streaming inference is not supported.
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        time_reduction_stride: int,
        conformer_input_dim: int,
        conformer_ffn_dim: int,
        conformer_num_layers: int,
        conformer_num_heads: int,
        conformer_depthwise_conv_kernel_size: int,
        conformer_dropout: float,
    ) -> None:
        super().__init__()
        self.time_reduction = _TimeReduction(time_reduction_stride)
        # Time reduction stacks `stride` consecutive frames, widening features accordingly.
        self.input_linear = torch.nn.Linear(input_dim * time_reduction_stride, conformer_input_dim)
        self.conformer = Conformer(
            input_dim=conformer_input_dim,
            num_heads=conformer_num_heads,
            ffn_dim=conformer_ffn_dim,
            num_layers=conformer_num_layers,
            depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
            dropout=conformer_dropout,
            use_group_norm=True,
            convolution_first=True,
        )
        self.output_linear = torch.nn.Linear(conformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes ``input`` frames; returns encodings and their valid lengths."""
        reduced, reduced_lengths = self.time_reduction(input, lengths)
        projected = self.input_linear(reduced)
        conformer_out, out_lengths = self.conformer(projected, reduced_lengths)
        encodings = self.layer_norm(self.output_linear(conformer_out))
        return encodings, out_lengths

    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        """Streaming inference is unsupported for this encoder; always raises."""
        raise RuntimeError("Conformer does not support streaming inference.")
def conformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    time_reduction_stride: int,
    conformer_input_dim: int,
    conformer_ffn_dim: int,
    conformer_num_layers: int,
    conformer_num_heads: int,
    conformer_depthwise_conv_kernel_size: int,
    conformer_dropout: float,
    num_symbols: int,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_hidden_dim: int,
    # Annotations fixed to match the documented types (were int/int/int).
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
    joiner_activation: str,
) -> RNNT:
    r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        conformer_input_dim (int): dimension of Conformer input.
        conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
        conformer_num_layers (int): number of Conformer layers to instantiate.
        conformer_num_heads (int): number of attention heads in each Conformer layer.
        conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        conformer_dropout (float): Conformer dropout probability.
        num_symbols (int): cardinality of set of target tokens.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.
        joiner_activation (str): activation function to use in the joiner.
            Must be one of ("relu", "tanh").

    Returns:
        RNNT:
            Conformer RNN-T model.
    """
    encoder = _ConformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        time_reduction_stride=time_reduction_stride,
        conformer_input_dim=conformer_input_dim,
        conformer_ffn_dim=conformer_ffn_dim,
        conformer_num_layers=conformer_num_layers,
        conformer_num_heads=conformer_num_heads,
        conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
        conformer_dropout=conformer_dropout,
    )
    predictor = _Predictor(
        num_symbols=num_symbols,
        output_dim=encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=lstm_hidden_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    joiner = _Joiner(encoding_dim, num_symbols, activation=joiner_activation)
    return RNNT(encoder, predictor, joiner)
def conformer_rnnt_base() -> RNNT:
    r"""Builds basic version of Conformer RNN-T model.

    Returns:
        RNNT:
            Conformer RNN-T model.
    """
    # Baseline hyperparameters, grouped by sub-network for readability.
    baseline_config = {
        # Transcription network (Conformer encoder).
        "input_dim": 80,
        "encoding_dim": 1024,
        "time_reduction_stride": 4,
        "conformer_input_dim": 256,
        "conformer_ffn_dim": 1024,
        "conformer_num_layers": 16,
        "conformer_num_heads": 4,
        "conformer_depthwise_conv_kernel_size": 31,
        "conformer_dropout": 0.1,
        # Prediction network.
        "num_symbols": 1024,
        "symbol_embedding_dim": 256,
        "num_lstm_layers": 2,
        "lstm_hidden_dim": 512,
        "lstm_layer_norm": True,
        "lstm_layer_norm_epsilon": 1e-5,
        "lstm_dropout": 0.3,
        # Joint network.
        "joiner_activation": "tanh",
    }
    return conformer_rnnt_model(**baseline_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment