"setup.py" did not exist on "c2b62b7ffe6aecc6dde4ecf90ebb7ee5e64db565"
Commit 2c79b55a authored by hwangjeff, committed by Facebook GitHub Bot

Add ConvEmformer module (#2358)

Summary:
Adds an implementation of the convolution-augmented streaming transformer (effectively, Emformer with a convolution block) described in https://arxiv.org/abs/2110.05241.

Continuation of https://github.com/pytorch/audio/issues/2324.

Pull Request resolved: https://github.com/pytorch/audio/pull/2358

Reviewed By: nateanl, xiaohui-zhang

Differential Revision: D36137992

Pulled By: hwangjeff

fbshipit-source-id: 9c7a7c233944fe9ef15b9ba397d7f0809da1f063
parent 2f4eb4ac
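For orientation before the diffs, here is a minimal usage sketch of the new module, assembled from the docstring examples this commit adds; the constructor arguments (80, 4, 1024, 12, 16, 8, right_context_length=4) are the illustrative values used there, not recommended settings.

# Minimal ConvEmformer usage sketch (values taken from the docstring examples in this commit).
import torch
from torchaudio.prototype.models import ConvEmformer

conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)

# Non-streaming forward pass over a full utterance.
input = torch.rand(10, 200, 80)          # (batch, num_frames, feature_dim)
lengths = torch.randint(1, 200, (10,))   # valid frames per batch element
output, output_lengths = conv_emformer(input, lengths)

# Streaming inference: each chunk carries segment_length + right_context_length (16 + 4) frames.
chunk = torch.rand(4, 20, 80)
chunk_lengths = torch.ones(4) * 20
chunk_out, chunk_out_lengths, states = conv_emformer.infer(chunk, chunk_lengths, None)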
@@ -13,3 +13,17 @@ conformer_rnnt_base
 ~~~~~~~~~~~~~~~~~~~
 .. autofunction:: conformer_rnnt_base
+
+ConvEmformer
+~~~~~~~~~~~~
+
+.. autoclass:: ConvEmformer
+
+  .. automethod:: forward
+
+  .. automethod:: infer
+
+References
+~~~~~~~~~~
+
+.. footbibliography::
@@ -238,6 +238,15 @@
   pages={6783-6787},
   year={2021}
 }
+@inproceedings{9747706,
+  author={Shi, Yangyang and Wu, Chunyang and Wang, Dilin and Xiao, Alex and Mahadeokar, Jay and Zhang, Xiaohui and Liu, Chunxi and Li, Ke and Shangguan, Yuan and Nagaraja, Varun and Kalinli, Ozlem and Seltzer, Mike},
+  booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+  title={Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution},
+  year={2022},
+  volume={},
+  number={},
+  pages={8277-8281},
+  doi={10.1109/ICASSP43922.2022.9747706}}
 @article{mises1929praktische,
   title={Praktische Verfahren der Gleichungsaufl{\"o}sung.},
   author={Mises, RV and Pollaczek-Geiringer, Hilda},
+from abc import ABC, abstractmethod
+
 import torch
 from torchaudio.models import Emformer
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script


-class EmformerTestImpl(TestBaseMixin):
-    def _gen_model(self, input_dim, right_context_length):
-        emformer = Emformer(
-            input_dim,
-            8,
-            256,
-            3,
-            4,
-            left_context_length=30,
-            right_context_length=right_context_length,
-            max_memory_size=1,
-        ).to(device=self.device, dtype=self.dtype)
-        return emformer
-
-    def _gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
-        input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
-        lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
-            device=self.device, dtype=self.dtype
-        )
-        return input, lengths
+class EmformerTestMixin(ABC):
+    @abstractmethod
+    def gen_model(self, input_dim, right_context_length):
+        pass
+
+    @abstractmethod
+    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
+        pass

     def setUp(self):
         super().setUp()
@@ -35,8 +25,8 @@ class EmformerTestImpl(TestBaseMixin):
         num_frames = 400
         right_context_length = 1

-        emformer = self._gen_model(input_dim, right_context_length)
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+        emformer = self.gen_model(input_dim, right_context_length)
+        input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
         scripted = torch_script(emformer)

         ref_out, ref_len = emformer(input, lengths)
@@ -52,12 +42,12 @@ class EmformerTestImpl(TestBaseMixin):
         num_frames = 5
         right_context_length = 1

-        emformer = self._gen_model(input_dim, right_context_length).eval()
+        emformer = self.gen_model(input_dim, right_context_length).eval()
         scripted = torch_script(emformer).eval()
         ref_state, scripted_state = None, None

         for _ in range(3):
-            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+            input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
             ref_out, ref_len, ref_state = emformer.infer(input, lengths, ref_state)
             scripted_out, scripted_len, scripted_state = scripted.infer(input, lengths, scripted_state)
             self.assertEqual(ref_out, scripted_out)
@@ -71,8 +61,8 @@ class EmformerTestImpl(TestBaseMixin):
         num_frames = 123
         right_context_length = 9

-        emformer = self._gen_model(input_dim, right_context_length)
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+        emformer = self.gen_model(input_dim, right_context_length)
+        input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)

         output, output_lengths = emformer(input, lengths)
@@ -86,11 +76,11 @@ class EmformerTestImpl(TestBaseMixin):
         num_frames = 6
         right_context_length = 2

-        emformer = self._gen_model(input_dim, right_context_length).eval()
+        emformer = self.gen_model(input_dim, right_context_length).eval()
         state = None

         for _ in range(3):
-            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+            input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
             output, output_lengths, state = emformer.infer(input, lengths, state)
             self.assertEqual((batch_size, num_frames - right_context_length, input_dim), output.shape)
             self.assertEqual((batch_size,), output_lengths.shape)
@@ -102,8 +92,8 @@ class EmformerTestImpl(TestBaseMixin):
         num_frames = 123
         right_context_length = 2

-        emformer = self._gen_model(input_dim, right_context_length)
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+        emformer = self.gen_model(input_dim, right_context_length)
+        input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)

         _, output_lengths = emformer(input, lengths)
         self.assertEqual(lengths, output_lengths)
@@ -114,7 +104,29 @@ class EmformerTestImpl(TestBaseMixin):
         num_frames = 6
         right_context_length = 2

-        emformer = self._gen_model(input_dim, right_context_length).eval()
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+        emformer = self.gen_model(input_dim, right_context_length).eval()
+        input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)

         _, output_lengths, _ = emformer.infer(input, lengths)
         self.assertEqual(torch.clamp(lengths - right_context_length, min=0), output_lengths)
+
+
+class EmformerTestImpl(EmformerTestMixin, TestBaseMixin):
+    def gen_model(self, input_dim, right_context_length):
+        emformer = Emformer(
+            input_dim,
+            8,
+            256,
+            3,
+            4,
+            left_context_length=30,
+            right_context_length=right_context_length,
+            max_memory_size=1,
+        ).to(device=self.device, dtype=self.dtype)
+        return emformer
+
+    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
+        input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
+        lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
+            device=self.device, dtype=self.dtype
+        )
+        return input, lengths
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl


class ConvEmformerFloat32CPUTest(ConvEmformerTestImpl, PytorchTestCase):
    dtype = torch.float32
    device = torch.device("cpu")


class ConvEmformerFloat64CPUTest(ConvEmformerTestImpl, PytorchTestCase):
    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl


@skipIfNoCuda
class ConvEmformerFloat32GPUTest(ConvEmformerTestImpl, PytorchTestCase):
    dtype = torch.float32
    device = torch.device("cuda")


@skipIfNoCuda
class ConvEmformerFloat64GPUTest(ConvEmformerTestImpl, PytorchTestCase):
    dtype = torch.float64
    device = torch.device("cuda")
import torch
from torchaudio.prototype.models.conv_emformer import ConvEmformer
from torchaudio_unittest.common_utils import TestBaseMixin
from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestMixin


class ConvEmformerTestImpl(EmformerTestMixin, TestBaseMixin):
    def gen_model(self, input_dim, right_context_length):
        emformer = ConvEmformer(
            input_dim,
            8,
            256,
            3,
            4,
            12,
            left_context_length=30,
            right_context_length=right_context_length,
            max_memory_size=1,
        ).to(device=self.device, dtype=self.dtype)
        return emformer

    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
        input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
        lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
            device=self.device, dtype=self.dtype
        )
        return input, lengths
@@ -586,53 +586,14 @@ class _EmformerLayer(torch.nn.Module):
         return output_utterance, output_right_context, output_state, output_mems


-class Emformer(torch.nn.Module):
-    r"""Implements the Emformer architecture introduced in
-    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
-    [:footcite:`shi2021emformer`].
-
-    Args:
-        input_dim (int): input dimension.
-        num_heads (int): number of attention heads in each Emformer layer.
-        ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
-        num_layers (int): number of Emformer layers to instantiate.
-        segment_length (int): length of each input segment.
-        dropout (float, optional): dropout probability. (Default: 0.0)
-        activation (str, optional): activation function to use in each Emformer layer's
-            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
-        left_context_length (int, optional): length of left context. (Default: 0)
-        right_context_length (int, optional): length of right context. (Default: 0)
-        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
-        weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
-            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
-        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
-        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
-
-    Examples:
-        >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
-        >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
-        >>> lengths = torch.randint(1, 200, (128,))  # batch
-        >>> output = emformer(input, lengths)
-        >>> input = torch.rand(128, 5, 512)
-        >>> lengths = torch.ones(128) * 5
-        >>> output, lengths, states = emformer.infer(input, lengths, None)
-    """
-
+class _EmformerImpl(torch.nn.Module):
     def __init__(
         self,
-        input_dim: int,
-        num_heads: int,
-        ffn_dim: int,
-        num_layers: int,
+        emformer_layers: torch.nn.ModuleList,
         segment_length: int,
-        dropout: float = 0.0,
-        activation: str = "relu",
         left_context_length: int = 0,
         right_context_length: int = 0,
         max_memory_size: int = 0,
-        weight_init_scale_strategy: str = "depthwise",
-        tanh_on_mem: bool = False,
-        negative_inf: float = -1e8,
     ):
         super().__init__()
@@ -642,27 +603,7 @@ class Emformer(torch.nn.Module):
             stride=segment_length,
             ceil_mode=True,
         )
-
-        weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
-        self.emformer_layers = torch.nn.ModuleList(
-            [
-                _EmformerLayer(
-                    input_dim,
-                    num_heads,
-                    ffn_dim,
-                    segment_length,
-                    dropout=dropout,
-                    activation=activation,
-                    left_context_length=left_context_length,
-                    max_memory_size=max_memory_size,
-                    weight_init_gain=weight_init_gains[layer_idx],
-                    tanh_on_mem=tanh_on_mem,
-                    negative_inf=negative_inf,
-                )
-                for layer_idx in range(num_layers)
-            ]
-        )
+        self.emformer_layers = emformer_layers
         self.left_context_length = left_context_length
         self.right_context_length = right_context_length
         self.segment_length = segment_length
@@ -816,7 +757,7 @@
             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                 number of valid frames for i-th batch element in ``input``.
             states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
-                representing Emformer internal state generated in preceding invocation of ``infer``. (Default: ``None``)
+                representing internal state generated in preceding invocation of ``infer``. (Default: ``None``)

         Returns:
             (Tensor, Tensor, List[List[Tensor]]):
@@ -826,7 +767,7 @@
                     output lengths, with shape `(B,)` and i-th element representing
                     number of valid frames for i-th batch element in output frames.
                 List[List[Tensor]]
-                    output states; list of lists of tensors representing Emformer internal state
+                    output states; list of lists of tensors representing internal state
                     generated in current invocation of ``infer``.
         """
         assert input.size(1) == self.segment_length + self.right_context_length, (
@@ -857,3 +798,79 @@ class Emformer(torch.nn.Module):
             output_states.append(output_state)

         return output.permute(1, 0, 2), output_lengths, output_states
+
+
+class Emformer(_EmformerImpl):
+    r"""Implements the Emformer architecture introduced in
+    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
+    [:footcite:`shi2021emformer`].
+
+    Args:
+        input_dim (int): input dimension.
+        num_heads (int): number of attention heads in each Emformer layer.
+        ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
+        num_layers (int): number of Emformer layers to instantiate.
+        segment_length (int): length of each input segment.
+        dropout (float, optional): dropout probability. (Default: 0.0)
+        activation (str, optional): activation function to use in each Emformer layer's
+            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+        left_context_length (int, optional): length of left context. (Default: 0)
+        right_context_length (int, optional): length of right context. (Default: 0)
+        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+        weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
+            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
+        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+
+    Examples:
+        >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
+        >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
+        >>> lengths = torch.randint(1, 200, (128,))  # batch
+        >>> output, lengths = emformer(input, lengths)
+        >>> input = torch.rand(128, 5, 512)
+        >>> lengths = torch.ones(128) * 5
+        >>> output, lengths, states = emformer.infer(input, lengths, None)
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_heads: int,
+        ffn_dim: int,
+        num_layers: int,
+        segment_length: int,
+        dropout: float = 0.0,
+        activation: str = "relu",
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        max_memory_size: int = 0,
+        weight_init_scale_strategy: Optional[str] = "depthwise",
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
+        emformer_layers = torch.nn.ModuleList(
+            [
+                _EmformerLayer(
+                    input_dim,
+                    num_heads,
+                    ffn_dim,
+                    segment_length,
+                    dropout=dropout,
+                    activation=activation,
+                    left_context_length=left_context_length,
+                    max_memory_size=max_memory_size,
+                    weight_init_gain=weight_init_gains[layer_idx],
+                    tanh_on_mem=tanh_on_mem,
+                    negative_inf=negative_inf,
+                )
+                for layer_idx in range(num_layers)
+            ]
+        )
+        super().__init__(
+            emformer_layers,
+            segment_length,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
+            max_memory_size=max_memory_size,
+        )
+from .conv_emformer import ConvEmformer
 from .rnnt import conformer_rnnt_base, conformer_rnnt_model

 __all__ = [
     "conformer_rnnt_base",
     "conformer_rnnt_model",
+    "ConvEmformer",
 ]
@@ -2,7 +2,7 @@ import math
 from typing import List, Optional, Tuple

 import torch
-from torchaudio.models.emformer import _EmformerAttention
+from torchaudio.models.emformer import _EmformerAttention, _EmformerImpl, _get_weight_init_gains


 def _get_activation_module(activation: str) -> torch.nn.Module:
@@ -70,10 +70,13 @@ class _ConvolutionModule(torch.nn.Module):
         right_context_segments = right_context_segments.permute(0, 2, 1, 3).reshape(
             num_segments * B, self.right_context_length, D
         )
-        pad_segments = [
-            utterance[(idx + 1) * self.segment_length : (idx + 1) * self.segment_length + self.state_size, :, :]
-            for idx in range(0, num_segments)
-        ]  # [(kernel_size - 1, B, D), ...]
+        pad_segments = []  # [(kernel_size - 1, B, D), ...]
+        for seg_idx in range(num_segments):
+            end_idx = min(self.state_size + (seg_idx + 1) * self.segment_length, utterance.size(0))
+            start_idx = end_idx - self.state_size
+            pad_segments.append(utterance[start_idx:end_idx, :, :])
         pad_segments = torch.cat(pad_segments, dim=1).permute(1, 0, 2)  # (num_segments * B, kernel_size - 1, D)
         return torch.cat([pad_segments, right_context_segments], dim=1).permute(0, 2, 1)
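To see why the padding indices are now clamped, here is a small standalone check of the index arithmetic (illustrative values only, not library code):

# With T utterance frames, segment_length=4, and state_size=kernel_size-1=3,
# the old slices for later segments can run past T, yielding pad segments
# shorter than state_size; clamping end_idx keeps every span exactly state_size long.
T, segment_length, state_size, num_segments = 10, 4, 3, 3

old_spans = [((i + 1) * segment_length, (i + 1) * segment_length + state_size) for i in range(num_segments)]
new_spans = []
for seg_idx in range(num_segments):
    end_idx = min(state_size + (seg_idx + 1) * segment_length, T)
    start_idx = end_idx - state_size
    new_spans.append((start_idx, end_idx))

print(old_spans)  # [(4, 7), (8, 11), (12, 15)] -- last spans extend beyond T=10
print(new_spans)  # [(4, 7), (7, 10), (7, 10)]  -- every span has length state_size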
@@ -100,16 +103,20 @@
             dtype=input.dtype,
         )  # (B, D, T)
         state_x_utterance = torch.cat([state, x_utterance], dim=2)
-        # (B * num_segments, D, right_context_length + kernel_size - 1)
-        right_context_block = self._split_right_context(state_x_utterance.permute(2, 0, 1), x_right_context)
-        conv_right_context_block = self.conv(right_context_block)  # (B * num_segments, D, right_context_length)
-        # (T_right_context, B, D)
-        conv_right_context = self._merge_right_context(conv_right_context_block, input.size(1))
         conv_utterance = self.conv(state_x_utterance)  # (B, D, T_utterance)
         conv_utterance = conv_utterance.permute(2, 0, 1)
-        y = torch.cat([conv_right_context, conv_utterance], dim=0)
+        if self.right_context_length > 0:
+            # (B * num_segments, D, right_context_length + kernel_size - 1)
+            right_context_block = self._split_right_context(state_x_utterance.permute(2, 0, 1), x_right_context)
+            conv_right_context_block = self.conv(right_context_block)  # (B * num_segments, D, right_context_length)
+            # (T_right_context, B, D)
+            conv_right_context = self._merge_right_context(conv_right_context_block, input.size(1))
+            y = torch.cat([conv_right_context, conv_utterance], dim=0)
+        else:
+            y = conv_utterance
         output = self.post_conv(y) + input
         new_state = state_x_utterance[:, :, -self.state_size :]
         return output[right_context.size(0) :], output[: right_context.size(0)], new_state
@@ -431,3 +438,87 @@ class _ConvEmformerLayer(torch.nn.Module):
         )
         output_state = self._pack_state(next_k, next_v, utterance.size(0), mems, conv_cache, state)
         return output_utterance, output_right_context, output_state, next_m
+
+
+class ConvEmformer(_EmformerImpl):
+    r"""Implements the convolution-augmented streaming transformer architecture introduced in
+    *Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution*
+    [:footcite:`9747706`].
+
+    Args:
+        input_dim (int): input dimension.
+        num_heads (int): number of attention heads in each ConvEmformer layer.
+        ffn_dim (int): hidden layer dimension of each ConvEmformer layer's feedforward network.
+        num_layers (int): number of ConvEmformer layers to instantiate.
+        segment_length (int): length of each input segment.
+        kernel_size (int): size of kernel to use in convolution modules.
+        dropout (float, optional): dropout probability. (Default: 0.0)
+        ffn_activation (str, optional): activation function to use in feedforward networks.
+            Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+        left_context_length (int, optional): length of left context. (Default: 0)
+        right_context_length (int, optional): length of right context. (Default: 0)
+        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+        weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
+            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
+        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+        conv_activation (str, optional): activation function to use in convolution modules.
+            Must be one of ("relu", "gelu", "silu"). (Default: "silu")
+
+    Examples:
+        >>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
+        >>> input = torch.rand(10, 200, 80)
+        >>> lengths = torch.randint(1, 200, (10,))
+        >>> output, lengths = conv_emformer(input, lengths)
+        >>> input = torch.rand(4, 20, 80)
+        >>> lengths = torch.ones(4) * 20
+        >>> output, lengths, states = conv_emformer.infer(input, lengths, None)
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_heads: int,
+        ffn_dim: int,
+        num_layers: int,
+        segment_length: int,
+        kernel_size: int,
+        dropout: float = 0.0,
+        ffn_activation: str = "relu",
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        max_memory_size: int = 0,
+        weight_init_scale_strategy: Optional[str] = "depthwise",
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+        conv_activation: str = "silu",
+    ):
+        weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
+        emformer_layers = torch.nn.ModuleList(
+            [
+                _ConvEmformerLayer(
+                    input_dim,
+                    num_heads,
+                    ffn_dim,
+                    segment_length,
+                    kernel_size,
+                    dropout=dropout,
+                    ffn_activation=ffn_activation,
+                    left_context_length=left_context_length,
+                    right_context_length=right_context_length,
+                    max_memory_size=max_memory_size,
+                    weight_init_gain=weight_init_gains[layer_idx],
+                    tanh_on_mem=tanh_on_mem,
+                    negative_inf=negative_inf,
+                    conv_activation=conv_activation,
+                )
+                for layer_idx in range(num_layers)
+            ]
+        )
+        super().__init__(
+            emformer_layers,
+            segment_length,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
+            max_memory_size=max_memory_size,
+        )
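As a closing note, the chunk-by-chunk streaming pattern exercised by the infer tests above looks roughly like the following sketch; chunk contents here are random placeholders, and the hyperparameters are again the illustrative docstring values.

import torch
from torchaudio.prototype.models import ConvEmformer

segment_length, right_context_length = 16, 4
model = ConvEmformer(80, 4, 1024, 12, segment_length, 8, right_context_length=right_context_length).eval()

state = None
with torch.no_grad():
    for _ in range(3):
        # Each streaming chunk must contain segment_length + right_context_length frames.
        chunk = torch.rand(1, segment_length + right_context_length, 80)
        chunk_lengths = torch.tensor([segment_length + right_context_length])
        output, output_lengths, state = model.infer(chunk, chunk_lengths, state)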