Unverified Commit 48cfbf2b authored by hwangjeff's avatar hwangjeff Committed by GitHub
Browse files

Introduce Emformer (#1801)

Adds an implementation of Emformer, a memory-efficient transformer architecture 
introduced in https://ieeexplore.ieee.org/document/9414560 that targets low-latency 
streaming speech recognition applications.
parent e3734fef
......@@ -39,6 +39,7 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
compliance.kaldi
kaldi_io
utils
prototype
.. toctree::
......
.. role:: hidden
:class: hidden-section
torchaudio.prototype.emformer
=============================
.. currentmodule:: torchaudio.prototype.emformer
Emformer is a prototype feature; see `here <https://pytorch.org/audio>`_
for more information on prototype features.
It is available only within nightly builds and must be imported
explicitly, e.g. via ``from torchaudio.prototype.emformer import Emformer``.
Emformer
~~~~~~~~
.. autoclass:: Emformer
.. automethod:: forward
.. automethod:: infer
References
~~~~~~~~~~
.. footbibliography::
......@@ -178,6 +178,13 @@
year={2016},
organization={IEEE}
}
@inproceedings{shi2021emformer,
title={Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition},
author={Shi, Yangyang and Wang, Yongqiang and Wu, Chunyang and Yeh, Ching-Feng and Chan, Julian and Zhang, Frank and Le, Duc and Seltzer, Mike},
booktitle={ICASSP 2021 - 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={6783--6787},
year={2021}
}
@article{mises1929praktische,
title={Praktische Verfahren der Gleichungsaufl{\"o}sung.},
author={Mises, RV and Pollaczek-Geiringer, Hilda},
......
import torch
from torchaudio_unittest.prototype.emformer_test_impl import EmformerTestImpl
from torchaudio_unittest.common_utils import PytorchTestCase
class EmformerFloat32CPUTest(EmformerTestImpl, PytorchTestCase):
    """Run the shared Emformer test suite (EmformerTestImpl) in float32 on CPU."""

    # Consumed by EmformerTestImpl when building models and inputs.
    dtype = torch.float32
    device = torch.device("cpu")
class EmformerFloat64CPUTest(EmformerTestImpl, PytorchTestCase):
    """Run the shared Emformer test suite (EmformerTestImpl) in float64 on CPU."""

    # Consumed by EmformerTestImpl when building models and inputs.
    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.prototype.emformer_test_impl import EmformerTestImpl
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
@skipIfNoCuda
class EmformerFloat32GPUTest(EmformerTestImpl, PytorchTestCase):
    """Run the shared Emformer test suite in float32 on CUDA (skipped if unavailable)."""

    # Consumed by EmformerTestImpl when building models and inputs.
    dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
class EmformerFloat64GPUTest(EmformerTestImpl, PytorchTestCase):
    """Run the shared Emformer test suite in float64 on CUDA (skipped if unavailable)."""

    # Consumed by EmformerTestImpl when building models and inputs.
    dtype = torch.float64
    device = torch.device("cuda")
import torch
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
from torchaudio.prototype import Emformer
class EmformerTestImpl(TestBaseMixin):
    """Device/dtype-agnostic test suite for the prototype :class:`Emformer` module.

    Concrete subclasses supply ``dtype`` and ``device`` class attributes (via
    ``TestBaseMixin``); see the CPU/GPU test classes that mix this in.
    Covers TorchScript consistency and output shape/length contracts of the
    ``forward`` and ``infer`` methods.
    """

    def _gen_model(self, input_dim, right_context_length):
        """Build an Emformer with fixed hyperparameters on ``self.device``/``self.dtype``.

        Fixed values: 8 attention heads, FFN dim 256, 3 layers, segment length 4,
        left context 30, max memory size 1. Only ``input_dim`` and
        ``right_context_length`` vary across tests.
        """
        emformer = Emformer(
            input_dim,
            8,
            256,
            3,
            segment_length=4,
            left_context_length=30,
            right_context_length=right_context_length,
            max_memory_size=1,
        ).to(device=self.device, dtype=self.dtype)
        return emformer

    def _gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
        """Generate a random ``(batch, frames, dim)`` input and per-sequence lengths.

        Lengths are sampled from ``[1, num_frames - right_context_length)`` so every
        sequence keeps at least ``right_context_length`` trailing frames.
        NOTE(review): ``lengths`` is cast to ``self.dtype`` (a float type in the
        concrete test classes), not an integer dtype — presumably accepted by
        Emformer, but confirm this is intentional.
        """
        input = torch.rand(batch_size, num_frames, input_dim).to(
            device=self.device, dtype=self.dtype
        )
        lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
            device=self.device, dtype=self.dtype
        )
        return input, lengths

    def test_torchscript_consistency_forward(self):
        r"""Verify that scripting Emformer does not change the behavior of method `forward`."""
        input_dim = 128
        batch_size = 10
        num_frames = 400
        right_context_length = 1
        emformer = self._gen_model(input_dim, right_context_length)
        input, lengths = self._gen_inputs(
            input_dim, batch_size, num_frames, right_context_length
        )
        scripted = torch_script(emformer)
        # Eager and scripted modules must agree on both outputs and lengths
        # for the same input batch.
        ref_out, ref_len = emformer(input, lengths)
        scripted_out, scripted_len = scripted(input, lengths)
        self.assertEqual(ref_out, scripted_out)
        self.assertEqual(ref_len, scripted_len)

    def test_torchscript_consistency_infer(self):
        r"""Verify that scripting Emformer does not change the behavior of method `infer`."""
        input_dim = 128
        batch_size = 10
        num_frames = 400
        right_context_length = 1
        emformer = self._gen_model(input_dim, right_context_length).eval()
        scripted = torch_script(emformer).eval()
        ref_state, scripted_state = None, None
        # Run three streaming steps, threading the returned state back in, so the
        # eager and scripted modules are compared across state updates as well.
        for _ in range(3):
            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
            ref_out, ref_len, ref_state = emformer.infer(input, lengths, ref_state)
            scripted_out, scripted_len, scripted_state = scripted.infer(
                input, lengths, scripted_state
            )
            self.assertEqual(ref_out, scripted_out)
            self.assertEqual(ref_len, scripted_len)
            self.assertEqual(ref_state, scripted_state)

    def test_output_shape_forward(self):
        r"""Check that method `forward` produces correctly-shaped outputs."""
        input_dim = 128
        batch_size = 10
        num_frames = 123
        right_context_length = 9
        emformer = self._gen_model(input_dim, right_context_length)
        input, lengths = self._gen_inputs(
            input_dim, batch_size, num_frames, right_context_length
        )
        output, output_lengths = emformer(input, lengths)
        # `forward` strips the trailing right-context frames from the time axis.
        self.assertEqual(
            (batch_size, num_frames - right_context_length, input_dim), output.shape
        )
        self.assertEqual((batch_size,), output_lengths.shape)

    def test_output_shape_infer(self):
        r"""Check that method `infer` produces correctly-shaped outputs."""
        input_dim = 256
        batch_size = 5
        num_frames = 200
        right_context_length = 2
        emformer = self._gen_model(input_dim, right_context_length).eval()
        state = None
        # Shape contract must hold on every streaming step, not just the first.
        for _ in range(3):
            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
            output, output_lengths, state = emformer.infer(input, lengths, state)
            self.assertEqual(
                (batch_size, num_frames - right_context_length, input_dim), output.shape
            )
            self.assertEqual((batch_size,), output_lengths.shape)

    def test_output_lengths_forward(self):
        r"""Check that method `forward` returns input `lengths` unmodified."""
        input_dim = 88
        batch_size = 13
        num_frames = 123
        right_context_length = 2
        emformer = self._gen_model(input_dim, right_context_length)
        input, lengths = self._gen_inputs(
            input_dim, batch_size, num_frames, right_context_length
        )
        _, output_lengths = emformer(input, lengths)
        self.assertEqual(lengths, output_lengths)

    def test_output_lengths_infer(self):
        r"""Check that method `infer` returns input `lengths` with right context length subtracted."""
        input_dim = 88
        batch_size = 13
        num_frames = 123
        right_context_length = 2
        emformer = self._gen_model(input_dim, right_context_length).eval()
        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
        _, output_lengths, _ = emformer.infer(input, lengths)
        # Clamp guards against negative lengths for sequences shorter than the
        # right context.
        self.assertEqual(
            torch.clamp(lengths - right_context_length, min=0), output_lengths
        )
from .emformer import Emformer
__all__ = ["Emformer"]
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment