Commit cbf1b839 authored by hwangjeff, committed by Facebook GitHub Bot

Apply minor fixes to Emformer implementation (#2252)

Summary:
Noticed some items to clean up in `Emformer`.
- Make `segment_length` a required argument in `_EmformerLayer`.
- Remove unused variables from `_unpack_state` and `_gen_attention_mask`.

These don't affect `Emformer`'s functionality or public API.
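As a sanity check that the public interface is untouched, here is an illustrative construction of `torchaudio.models.Emformer` following the pattern in the library documentation; the hyperparameter values below are arbitrary and not part of this PR:

```python
import torch
from torchaudio.models import Emformer

# Arbitrary illustrative hyperparameters; the constructor signature is the same
# before and after this change.
emformer = Emformer(
    input_dim=512,
    num_heads=8,
    ffn_dim=2048,
    num_layers=20,
    segment_length=4,
    right_context_length=1,
)

# Forward pass over a full utterance, right-padded with right-context frames.
input = torch.rand(128, 400, 512)        # (batch, num_frames, feature_dim)
lengths = torch.randint(1, 200, (128,))  # valid frame counts per batch element
output, output_lengths = emformer(input, lengths)
```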

Pull Request resolved: https://github.com/pytorch/audio/pull/2252

Reviewed By: carolineechen, mthrok

Differential Revision: D34321430

Pulled By: hwangjeff

fbshipit-source-id: 38a5046f633a3e625352c476ef71c78380ccc597
parent 3184aebc
@@ -321,11 +321,11 @@ class _EmformerLayer(torch.nn.Module):
         input_dim (int): input dimension.
         num_heads (int): number of attention heads.
         ffn_dim: (int): hidden layer dimension of feedforward network.
+        segment_length (int): length of each input segment.
         dropout (float, optional): dropout probability. (Default: 0.0)
         activation (str, optional): activation function to use in feedforward network.
             Must be one of ("relu", "gelu", "silu"). (Default: "relu")
         left_context_length (int, optional): length of left context. (Default: 0)
-        segment_length (int, optional): length of each input segment. (Default: 128)
         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
         weight_init_gain (float or None, optional): scale factor to apply when initializing
             attention module parameters. (Default: ``None``)
@@ -338,10 +338,10 @@ class _EmformerLayer(torch.nn.Module):
         input_dim: int,
         num_heads: int,
         ffn_dim: int,
+        segment_length: int,
         dropout: float = 0.0,
         activation: str = "relu",
         left_context_length: int = 0,
-        segment_length: int = 128,
         max_memory_size: int = 0,
         weight_init_gain: Optional[float] = None,
         tanh_on_mem: bool = False,
@@ -386,9 +386,7 @@ class _EmformerLayer(torch.nn.Module):
         past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
         return [empty_memory, left_context_key, left_context_val, past_length]
 
-    def _unpack_state(
-        self, utterance: torch.Tensor, mems: torch.Tensor, state: List[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         past_length = state[3][0][0].item()
         past_left_context_length = min(self.left_context_length, past_length)
         past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
@@ -474,7 +472,7 @@ class _EmformerLayer(torch.nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         if state is None:
             state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_mems, lc_key, lc_val = self._unpack_state(utterance, mems, state)
+        pre_mems, lc_key, lc_val = self._unpack_state(state)
         if self.use_mem:
             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
             summary = summary[:1]
@@ -652,10 +650,10 @@ class Emformer(torch.nn.Module):
                     input_dim,
                     num_heads,
                     ffn_dim,
+                    segment_length,
                     dropout=dropout,
                     activation=activation,
                     left_context_length=left_context_length,
-                    segment_length=segment_length,
                     max_memory_size=max_memory_size,
                     weight_init_gain=weight_init_gains[layer_idx],
                     tanh_on_mem=tanh_on_mem,
@@ -718,7 +716,7 @@ class Emformer(torch.nn.Module):
         return col_widths
 
     def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
-        utterance_length, batch_size, _ = input.shape
+        utterance_length = input.size(0)
         num_segs = math.ceil(utterance_length / self.segment_length)
 
         rc_mask = []
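For reference, the state-handling code changed above (`_init_state`/`_unpack_state`) is exercised through `Emformer.infer` during streaming inference; a rough, illustrative sketch (values and shapes are placeholders, not part of this commit):

```python
import torch
from torchaudio.models import Emformer

emformer = Emformer(
    input_dim=512,
    num_heads=8,
    ffn_dim=2048,
    num_layers=20,
    segment_length=4,
    right_context_length=1,
)

# infer() consumes one segment (plus right context) at a time and threads the
# returned state into the next call, which is the path where _unpack_state runs.
state = None
for _ in range(3):  # three consecutive segments of a stream
    segment = torch.rand(1, 4 + 1, 512)  # (batch, segment_length + right_context_length, feature_dim)
    lengths = torch.tensor([5])          # valid frames in this chunk
    output, output_lengths, state = emformer.infer(segment, lengths, state)
```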