Commit 78ce7010 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add Emformer RNN-T model (#2003)

Summary:
Adds streaming-capable recurrent neural network transducer (RNN-T) model that uses Emformer for its transcription network. Includes two factory functions — one that allows for building a custom model, and one that builds a preconfigured base model.

Pull Request resolved: https://github.com/pytorch/audio/pull/2003

Reviewed By: nateanl

Differential Revision: D32440879

Pulled By: hwangjeff

fbshipit-source-id: 601cb1de368427f25e3b7d120e185960595d2360
parent b4184dc6
.. role:: hidden .. role:: hidden
:class: hidden-section :class: hidden-section
torchaudio.prototype.emformer torchaudio.prototype
============================= ====================
.. py:module:: torchaudio.prototype.emformer .. py:module:: torchaudio.prototype
.. currentmodule:: torchaudio.prototype.emformer .. currentmodule:: torchaudio.prototype
Emformer is a prototype feature; see `here <https://pytorch.org/audio>`_ ``torchaudio.prototype`` provides prototype features;
for more information on prototype features. see `here <https://pytorch.org/audio>`_ for more information on prototype features.
It is available only within nightly builds and must be imported The module is available only within nightly builds and must be imported
explicitly, e.g. via ``from torchaudio.prototype.emformer import Emformer``. explicitly, e.g. ``import torchaudio.prototype``.
Emformer Emformer
~~~~~~~~ ~~~~~~~~
...@@ -22,6 +22,33 @@ Emformer ...@@ -22,6 +22,33 @@ Emformer
.. automethod:: infer .. automethod:: infer
RNNT
~~~~
.. autoclass:: RNNT
.. automethod:: forward
.. automethod:: transcribe_streaming
.. automethod:: transcribe
.. automethod:: predict
.. automethod:: join
emformer_rnnt_base
~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_base
emformer_rnnt_model
~~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_model
References References
~~~~~~~~~~ ~~~~~~~~~~
......
import torch
from torchaudio_unittest.prototype.rnnt_test_impl import RNNTTestImpl
from torchaudio_unittest.common_utils import PytorchTestCase
class RNNTFloat32CPUTest(RNNTTestImpl, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
class RNNTFloat64CPUTest(RNNTTestImpl, PytorchTestCase):
dtype = torch.float64
device = torch.device("cpu")
import torch
from torchaudio_unittest.prototype.rnnt_test_impl import RNNTTestImpl
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
@skipIfNoCuda
class RNNTFloat32GPUTest(RNNTTestImpl, PytorchTestCase):
dtype = torch.float32
device = torch.device("cuda")
@skipIfNoCuda
class RNNTFloat64GPUTest(RNNTTestImpl, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
import torch
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
from torchaudio.prototype.rnnt import emformer_rnnt_model
class RNNTTestImpl(TestBaseMixin):
def _get_input_config(self):
model_config = self._get_model_config()
return {
"batch_size": 8,
"max_input_length": 61,
"num_symbols": model_config["num_symbols"],
"max_target_length": 23,
"input_dim": model_config["input_dim"],
"right_context_length": model_config["right_context_length"],
"encoding_dim": model_config["encoding_dim"],
"joiner_max_input_length": 61 // model_config["time_reduction_stride"],
"segment_length": model_config["segment_length"],
"time_reduction_stride": model_config["time_reduction_stride"],
}
def _get_model_config(self):
return {
"input_dim": 80,
"encoding_dim": 128,
"num_symbols": 256,
"segment_length": 16,
"right_context_length": 4,
"time_reduction_input_dim": 128,
"time_reduction_stride": 4,
"transformer_num_heads": 4,
"transformer_ffn_dim": 64,
"transformer_num_layers": 3,
"transformer_dropout": 0.0,
"transformer_activation": "relu",
"transformer_left_context_length": 30,
"transformer_max_memory_size": 0,
"transformer_weight_init_scale_strategy": "depthwise",
"transformer_tanh_on_mem": True,
"symbol_embedding_dim": 64,
"num_lstm_layers": 2,
"lstm_layer_norm": True,
"lstm_layer_norm_epsilon": 1e-3,
"lstm_dropout": 0.0,
}
def _get_model(self):
return (
emformer_rnnt_model(**self._get_model_config())
.to(device=self.device, dtype=self.dtype)
.eval()
)
def _get_transcriber_input(self):
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
max_input_length = input_config["max_input_length"]
input_dim = input_config["input_dim"]
right_context_length = input_config["right_context_length"]
torch.random.manual_seed(31)
input = torch.rand(
batch_size, max_input_length + right_context_length, input_dim
).to(device=self.device, dtype=self.dtype)
lengths = torch.randint(1, max_input_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32
)
return input, lengths
def _get_transcriber_streaming_input(self):
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
segment_length = input_config["segment_length"]
input_dim = input_config["input_dim"]
right_context_length = input_config["right_context_length"]
torch.random.manual_seed(31)
input = torch.rand(
batch_size, segment_length + right_context_length, input_dim
).to(device=self.device, dtype=self.dtype)
lengths = torch.randint(
1, segment_length + right_context_length + 1, (batch_size,)
).to(device=self.device, dtype=torch.int32)
return input, lengths
def _get_predictor_input(self):
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
num_symbols = input_config["num_symbols"]
max_target_length = input_config["max_target_length"]
torch.random.manual_seed(31)
input = torch.randint(0, num_symbols, (batch_size, max_target_length)).to(
device=self.device, dtype=torch.int32
)
lengths = torch.randint(1, max_target_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32
)
return input, lengths
def _get_joiner_input(self):
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
joiner_max_input_length = input_config["joiner_max_input_length"]
max_target_length = input_config["max_target_length"]
input_dim = input_config["encoding_dim"]
torch.random.manual_seed(31)
utterance_encodings = torch.rand(
batch_size, joiner_max_input_length, input_dim
).to(device=self.device, dtype=self.dtype)
utterance_lengths = torch.randint(
0, joiner_max_input_length + 1, (batch_size,)
).to(device=self.device, dtype=torch.int32)
target_encodings = torch.rand(batch_size, max_target_length, input_dim).to(
device=self.device, dtype=self.dtype
)
target_lengths = torch.randint(0, max_target_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32
)
return utterance_encodings, utterance_lengths, target_encodings, target_lengths
def test_torchscript_consistency_forward(self):
r"""Verify that scripting RNNT does not change the behavior of method `forward`."""
inputs, input_lengths = self._get_transcriber_input()
targets, target_lengths = self._get_predictor_input()
rnnt = self._get_model()
scripted = torch_script(rnnt).eval()
ref_state, scripted_state = None, None
for _ in range(2):
ref_out, ref_input_lengths, ref_target_lengths, ref_state = rnnt(
inputs, input_lengths, targets, target_lengths, ref_state
)
(
scripted_out,
scripted_input_lengths,
scripted_target_lengths,
scripted_state,
) = scripted(inputs, input_lengths, targets, target_lengths, scripted_state)
self.assertEqual(ref_out, scripted_out)
self.assertEqual(ref_input_lengths, scripted_input_lengths)
self.assertEqual(ref_target_lengths, scripted_target_lengths)
self.assertEqual(ref_state, scripted_state)
def test_torchscript_consistency_transcribe(self):
r"""Verify that scripting RNNT does not change the behavior of method `transcribe`."""
input, lengths = self._get_transcriber_input()
rnnt = self._get_model()
scripted = torch_script(rnnt)
ref_out, ref_lengths = rnnt.transcribe(input, lengths)
scripted_out, scripted_lengths = scripted.transcribe(input, lengths)
self.assertEqual(ref_out, scripted_out)
self.assertEqual(ref_lengths, scripted_lengths)
def test_torchscript_consistency_transcribe_streaming(self):
r"""Verify that scripting RNNT does not change the behavior of method `transcribe_streaming`."""
input, lengths = self._get_transcriber_streaming_input()
rnnt = self._get_model()
scripted = torch_script(rnnt)
ref_state, scripted_state = None, None
for _ in range(2):
ref_out, ref_lengths, ref_state = rnnt.transcribe_streaming(
input, lengths, ref_state
)
(
scripted_out,
scripted_lengths,
scripted_state,
) = scripted.transcribe_streaming(input, lengths, scripted_state)
self.assertEqual(ref_out, scripted_out)
self.assertEqual(ref_lengths, scripted_lengths)
self.assertEqual(ref_state, scripted_state)
def test_torchscript_consistency_predict(self):
r"""Verify that scripting RNNT does not change the behavior of method `predict`."""
input, lengths = self._get_predictor_input()
rnnt = self._get_model()
scripted = torch_script(rnnt)
ref_state, scripted_state = None, None
for _ in range(2):
ref_out, ref_lengths, ref_state = rnnt.predict(input, lengths, ref_state)
scripted_out, scripted_lengths, scripted_state = scripted.predict(
input, lengths, scripted_state
)
self.assertEqual(ref_out, scripted_out)
self.assertEqual(ref_lengths, scripted_lengths)
self.assertEqual(ref_state, scripted_state)
def test_torchscript_consistency_join(self):
r"""Verify that scripting RNNT does not change the behavior of method `join`."""
(
utterance_encodings,
utterance_lengths,
target_encodings,
target_lengths,
) = self._get_joiner_input()
rnnt = self._get_model()
scripted = torch_script(rnnt)
ref_out, ref_src_lengths, ref_tgt_lengths = rnnt.join(
utterance_encodings, utterance_lengths, target_encodings, target_lengths
)
scripted_out, scripted_src_lengths, scripted_tgt_lengths = scripted.join(
utterance_encodings, utterance_lengths, target_encodings, target_lengths
)
self.assertEqual(ref_out, scripted_out)
self.assertEqual(ref_src_lengths, scripted_src_lengths)
self.assertEqual(ref_tgt_lengths, scripted_tgt_lengths)
def test_output_shape_forward(self):
r"""Check that method `forward` produces correctly-shaped outputs."""
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
joiner_max_input_length = input_config["joiner_max_input_length"]
max_target_length = input_config["max_target_length"]
num_symbols = input_config["num_symbols"]
inputs, input_lengths = self._get_transcriber_input()
targets, target_lengths = self._get_predictor_input()
rnnt = self._get_model()
state = None
for _ in range(2):
out, out_lengths, target_lengths, state = rnnt(
inputs, input_lengths, targets, target_lengths, state
)
self.assertEqual(
(batch_size, joiner_max_input_length, max_target_length, num_symbols),
out.shape,
)
self.assertEqual((batch_size,), out_lengths.shape)
self.assertEqual((batch_size,), target_lengths.shape)
def test_output_shape_transcribe(self):
r"""Check that method `transcribe` produces correctly-shaped outputs."""
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
max_input_length = input_config["max_input_length"]
input, lengths = self._get_transcriber_input()
model_config = self._get_model_config()
encoding_dim = model_config["encoding_dim"]
time_reduction_stride = model_config["time_reduction_stride"]
rnnt = self._get_model()
out, out_lengths = rnnt.transcribe(input, lengths)
self.assertEqual(
(batch_size, max_input_length // time_reduction_stride, encoding_dim),
out.shape,
)
self.assertEqual((batch_size,), out_lengths.shape)
def test_output_shape_transcribe_streaming(self):
r"""Check that method `transcribe_streaming` produces correctly-shaped outputs."""
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
segment_length = input_config["segment_length"]
encoding_dim = input_config["encoding_dim"]
time_reduction_stride = input_config["time_reduction_stride"]
input, lengths = self._get_transcriber_streaming_input()
rnnt = self._get_model()
state = None
for _ in range(2):
out, out_lengths, state = rnnt.transcribe_streaming(input, lengths, state)
self.assertEqual(
(batch_size, segment_length // time_reduction_stride, encoding_dim),
out.shape,
)
self.assertEqual((batch_size,), out_lengths.shape)
def test_output_shape_predict(self):
r"""Check that method `predict` produces correctly-shaped outputs."""
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
max_target_length = input_config["max_target_length"]
model_config = self._get_model_config()
encoding_dim = model_config["encoding_dim"]
input, lengths = self._get_predictor_input()
rnnt = self._get_model()
state = None
for _ in range(2):
out, out_lengths, state = rnnt.predict(input, lengths, state)
self.assertEqual((batch_size, max_target_length, encoding_dim), out.shape)
self.assertEqual((batch_size,), out_lengths.shape)
def test_output_shape_join(self):
r"""Check that method `join` produces correctly-shaped outputs."""
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
joiner_max_input_length = input_config["joiner_max_input_length"]
max_target_length = input_config["max_target_length"]
num_symbols = input_config["num_symbols"]
(
utterance_encodings,
utterance_lengths,
target_encodings,
target_lengths,
) = self._get_joiner_input()
rnnt = self._get_model()
out, src_lengths, tgt_lengths = rnnt.join(
utterance_encodings, utterance_lengths, target_encodings, target_lengths
)
self.assertEqual(
(batch_size, joiner_max_input_length, max_target_length, num_symbols),
out.shape,
)
self.assertEqual((batch_size,), src_lengths.shape)
self.assertEqual((batch_size,), tgt_lengths.shape)
from .emformer import Emformer from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
__all__ = ["Emformer"] __all__ = ["Emformer", "RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment