Commit 4c8fd760 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Clean up Emformer module (#2091)

Summary:
* Removes redundant declaration `right_context_blocks = []`, as flagged by kobenaxie.
* Adds random seed to tests, as flagged by carolineechen in other PRs.

Pull Request resolved: https://github.com/pytorch/audio/pull/2091

Reviewed By: mthrok

Differential Revision: D33340964

Pulled By: hwangjeff

fbshipit-source-id: a9de43e28d1bae7bd4806b280717b0d822bb42fc
parent ece03edc
@@ -24,6 +24,10 @@ class EmformerTestImpl(TestBaseMixin):
         )
         return input, lengths

+    def setUp(self):
+        super().setUp()
+        torch.random.manual_seed(29)
+
     def test_torchscript_consistency_forward(self):
         r"""Verify that scripting Emformer does not change the behavior of method `forward`."""
         input_dim = 128
......
@@ -669,8 +669,7 @@ class Emformer(torch.nn.Module):
         self.max_memory_size = max_memory_size

     def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
-        right_context_blocks = []
-        T, B, D = input.shape
+        T = input.shape[0]
         num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
         right_context_blocks = []
         for seg_idx in range(num_segs - 1):
@@ -765,7 +764,7 @@ class Emformer(torch.nn.Module):
         return attention_mask

     def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        r"""Forward pass for training.
+        r"""Forward pass for training and non-streaming inference.

         B: batch size;
         T: number of frames;
@@ -806,7 +805,7 @@ class Emformer(torch.nn.Module):
         lengths: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
-        r"""Forward pass for inference.
+        r"""Forward pass for streaming inference.

         B: batch size;
         T: number of frames;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment