Commit 87d7694d authored by hwangjeff, committed by Facebook GitHub Bot

Clean up Emformer (#2207)

Summary:
- Make `segment_length` a required rather than an optional argument, to force users to consciously choose input segment lengths for their use cases (see the usage sketch below).
- Clarify expected input shapes in API documentation.
- Adjust `infer` tests to reflect expected usage.
- Add an input-shape assertion to `infer`.
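
For a concrete picture of the new contract, here is a minimal usage sketch assembled from the docstring example updated in this diff (the dimensions are the example's values, not requirements):

```python
import torch
from torchaudio.models import Emformer

# segment_length (4 here) is now the fifth positional, required argument.
emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)

# Training / non-streaming path: input holds utterance frames right-padded
# with right_context_length frames; the output drops them (400 -> 399 here).
input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
lengths = torch.randint(1, 200, (128,))  # valid utterance frames per element
output, output_lengths = emformer(input, lengths)

# Streaming path: each chunk must hold exactly
# segment_length + right_context_length = 4 + 1 = 5 frames,
# or the new assertion in `infer` fires.
chunk = torch.rand(128, 5, 512)
chunk_lengths = torch.ones(128) * 5
output, output_lengths, states = emformer.infer(chunk, chunk_lengths, None)
```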

Pull Request resolved: https://github.com/pytorch/audio/pull/2207

Reviewed By: mthrok

Differential Revision: D34101205

Pulled By: hwangjeff

fbshipit-source-id: 1d1233d5edee5818d4669b4e47d44559e7ebb304
parent e5d567c9
@@ -10,7 +10,7 @@ class EmformerTestImpl(TestBaseMixin):
             8,
             256,
             3,
-            segment_length=4,
+            4,
             left_context_length=30,
             right_context_length=right_context_length,
             max_memory_size=1,
@@ -49,7 +49,7 @@ class EmformerTestImpl(TestBaseMixin):
         r"""Verify that scripting Emformer does not change the behavior of method `infer`."""
         input_dim = 128
         batch_size = 10
-        num_frames = 400
+        num_frames = 5
         right_context_length = 1
         emformer = self._gen_model(input_dim, right_context_length).eval()
@@ -57,7 +57,7 @@ class EmformerTestImpl(TestBaseMixin):
         ref_state, scripted_state = None, None
         for _ in range(3):
-            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
+            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
             ref_out, ref_len, ref_state = emformer.infer(input, lengths, ref_state)
             scripted_out, scripted_len, scripted_state = scripted.infer(input, lengths, scripted_state)
             self.assertEqual(ref_out, scripted_out)
@@ -83,14 +83,14 @@ class EmformerTestImpl(TestBaseMixin):
         r"""Check that method `infer` produces correctly-shaped outputs."""
         input_dim = 256
         batch_size = 5
-        num_frames = 200
+        num_frames = 6
         right_context_length = 2
         emformer = self._gen_model(input_dim, right_context_length).eval()
         state = None
         for _ in range(3):
-            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
+            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
             output, output_lengths, state = emformer.infer(input, lengths, state)
             self.assertEqual((batch_size, num_frames - right_context_length, input_dim), output.shape)
             self.assertEqual((batch_size,), output_lengths.shape)
@@ -111,10 +111,10 @@ class EmformerTestImpl(TestBaseMixin):
         r"""Check that method `infer` returns input `lengths` with right context length subtracted."""
         input_dim = 88
         batch_size = 13
-        num_frames = 123
+        num_frames = 6
         right_context_length = 2
         emformer = self._gen_model(input_dim, right_context_length).eval()
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
+        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
         _, output_lengths, _ = emformer.infer(input, lengths)
         self.assertEqual(torch.clamp(lengths - right_context_length, min=0), output_lengths)
@@ -598,12 +598,12 @@ class Emformer(torch.nn.Module):
         num_heads (int): number of attention heads in each Emformer layer.
         ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
         num_layers (int): number of Emformer layers to instantiate.
+        segment_length (int): length of each input segment.
         dropout (float, optional): dropout probability. (Default: 0.0)
         activation (str, optional): activation function to use in each Emformer layer's
             feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
         left_context_length (int, optional): length of left context. (Default: 0)
         right_context_length (int, optional): length of right context. (Default: 0)
-        segment_length (int, optional): length of each input segment. (Default: 128)
         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
         weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
             strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
@@ -611,10 +611,12 @@ class Emformer(torch.nn.Module):
         negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
 
     Examples:
-        >>> emformer = Emformer(512, 8, 2048, 20)
+        >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
         >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
         >>> lengths = torch.randint(1, 200, (128,))  # batch
         >>> output = emformer(input, lengths)
+        >>> input = torch.rand(128, 5, 512)
+        >>> lengths = torch.ones(128) * 5
+        >>> output, lengths, states = emformer.infer(input, lengths, None)
     """
@@ -624,11 +626,11 @@ class Emformer(torch.nn.Module):
         num_heads: int,
         ffn_dim: int,
         num_layers: int,
+        segment_length: int,
         dropout: float = 0.0,
         activation: str = "relu",
         left_context_length: int = 0,
         right_context_length: int = 0,
-        segment_length: int = 128,
         max_memory_size: int = 0,
         weight_init_scale_strategy: str = "depthwise",
         tanh_on_mem: bool = False,
@@ -767,19 +769,19 @@ class Emformer(torch.nn.Module):
         r"""Forward pass for training and non-streaming inference.
 
         B: batch size;
-        T: number of frames;
+        T: max number of input frames in batch;
         D: feature dimension of each frame.
 
         Args:
             input (torch.Tensor): utterance frames right-padded with right context frames, with
-                shape `(B, T, D)`.
+                shape `(B, T + right_context_length, D)`.
             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                number of valid frames for i-th batch element in ``input``.
+                number of valid utterance frames for i-th batch element in ``input``.
 
         Returns:
             (Tensor, Tensor):
                 Tensor
-                    output frames, with shape `(B, T - ``right_context_length``, D)`.
+                    output frames, with shape `(B, T, D)`.
                 Tensor
                     output lengths, with shape `(B,)` and i-th element representing
                     number of valid frames for i-th batch element in output frames.
@@ -808,12 +810,11 @@ class Emformer(torch.nn.Module):
         r"""Forward pass for streaming inference.
 
         B: batch size;
-        T: number of frames;
         D: feature dimension of each frame.
 
         Args:
             input (torch.Tensor): utterance frames right-padded with right context frames, with
-                shape `(B, T, D)`.
+                shape `(B, segment_length + right_context_length, D)`.
             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                 number of valid frames for i-th batch element in ``input``.
             states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
@@ -822,7 +823,7 @@ class Emformer(torch.nn.Module):
         Returns:
             (Tensor, Tensor, List[List[Tensor]]):
                 Tensor
-                    output frames, with shape `(B, T - ``right_context_length``, D)`.
+                    output frames, with shape `(B, segment_length, D)`.
                 Tensor
                     output lengths, with shape `(B,)` and i-th element representing
                     number of valid frames for i-th batch element in output frames.
@@ -830,6 +831,11 @@ class Emformer(torch.nn.Module):
                 output states; list of lists of tensors representing Emformer internal state
                 generated in current invocation of ``infer``.
         """
+        assert input.size(1) == self.segment_length + self.right_context_length, (
+            "Per configured segment_length and right_context_length"
+            f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
+            f", but got {input.size(1)}."
+        )
         input = input.permute(1, 0, 2)
         right_context_start_idx = input.size(0) - self.right_context_length
         right_context = input[right_context_start_idx:]
@@ -189,11 +189,11 @@ class _Transcriber(torch.nn.Module):
             transformer_num_heads,
             transformer_ffn_dim,
             transformer_num_layers,
+            segment_length // time_reduction_stride,
             dropout=transformer_dropout,
             activation=transformer_activation,
             left_context_length=transformer_left_context_length,
             right_context_length=right_context_length // time_reduction_stride,
-            segment_length=segment_length // time_reduction_stride,
             max_memory_size=transformer_max_memory_size,
             weight_init_scale_strategy=transformer_weight_init_scale_strategy,
             tanh_on_mem=transformer_tanh_on_mem,
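
Note that the `_Transcriber` hunk above specifies the segment length in raw input frames and scales it down by the time-reduction stride before handing it to `Emformer`. A small sketch of that arithmetic, with hypothetical values (only the integer division mirrors this diff):

```python
# Hypothetical configuration, for illustration only.
segment_length = 16        # streaming segment, in raw input frames
right_context_length = 4   # right context, in raw input frames
time_reduction_stride = 4  # frames merged per step by time reduction

# The inner Emformer operates on time-reduced frames.
emformer_segment_length = segment_length // time_reduction_stride        # 4
emformer_right_context = right_context_length // time_reduction_stride  # 1

# Per the new assertion, each chunk passed to Emformer.infer must then
# have emformer_segment_length + emformer_right_context == 5 frames
# along dimension 1.
```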