Commit 4c8fd760 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Clean up Emformer module (#2091)

Summary:
* Removes redundant declaration `right_context_blocks = []`, as flagged by kobenaxie.
* Adds random seed to tests, as flagged by carolineechen in other PRs.

Pull Request resolved: https://github.com/pytorch/audio/pull/2091

Reviewed By: mthrok

Differential Revision: D33340964

Pulled By: hwangjeff

fbshipit-source-id: a9de43e28d1bae7bd4806b280717b0d822bb42fc
parent ece03edc
......@@ -24,6 +24,10 @@ class EmformerTestImpl(TestBaseMixin):
)
return input, lengths
def setUp(self):
super().setUp()
torch.random.manual_seed(29)
def test_torchscript_consistency_forward(self):
r"""Verify that scripting Emformer does not change the behavior of method `forward`."""
input_dim = 128
......
......@@ -669,8 +669,7 @@ class Emformer(torch.nn.Module):
self.max_memory_size = max_memory_size
def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
right_context_blocks = []
T, B, D = input.shape
T = input.shape[0]
num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
right_context_blocks = []
for seg_idx in range(num_segs - 1):
......@@ -765,7 +764,7 @@ class Emformer(torch.nn.Module):
return attention_mask
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward pass for training.
r"""Forward pass for training and non-streaming inference.
B: batch size;
T: number of frames;
......@@ -806,7 +805,7 @@ class Emformer(torch.nn.Module):
lengths: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
r"""Forward pass for inference.
r"""Forward pass for streaming inference.
B: batch size;
T: number of frames;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment