Commit 8e3c6144 authored by moto, committed by Facebook GitHub Bot

Update context building to not delay the inference (#2213)

Summary:
Updating the context cacher so that the fetched audio chunk is used for inference immediately.

https://github.com/pytorch/audio/pull/2202#discussion_r802838174
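
For illustration, a minimal sketch of the new flow with made-up lengths and a stand-in stream of random chunks (the caching logic restates the updated ContextCacher in the diff below):

import torch
import torch.nn.functional as F

segment_length, context_length = 640, 320  # illustrative values only
context = torch.zeros(context_length)

for chunk in (torch.randn(segment_length) for _ in range(3)):  # stand-in for streamer.stream()
    if chunk.size(0) < segment_length:
        # pad a short trailing chunk up to the full segment length
        chunk = F.pad(chunk, (0, segment_length - chunk.size(0)))
    segment = torch.cat((context, chunk))  # cached past context + the chunk just fetched
    context = chunk[-context_length:]      # remember the tail for the next iteration
    # `segment` goes to the feature extractor / decoder in the same iteration,
    # instead of inference running one chunk behind the stream.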

Pull Request resolved: https://github.com/pytorch/audio/pull/2213

Reviewed By: hwangjeff

Differential Revision: D34235230

Pulled By: mthrok

fbshipit-source-id: 6e4aee7cca34ca81e40c0cb13497182f20f7f04e
parent 411b5dcf
@@ -81,12 +81,12 @@ token_processor = bundle.get_token_processor()
 #
 sample_rate = bundle.sample_rate
-frames_per_chunk = bundle.segment_length * bundle.hop_length
-right_context_size = bundle.right_context_length * bundle.hop_length
+segment_length = bundle.segment_length * bundle.hop_length
+context_length = bundle.right_context_length * bundle.hop_length
 print(f"Sample rate: {sample_rate}")
-print(f"Main segment: {frames_per_chunk} frames ({frames_per_chunk / sample_rate} seconds)")
-print(f"Right context: {right_context_size} frames ({right_context_size / sample_rate} seconds)")
+print(f"Main segment: {segment_length} frames ({segment_length / sample_rate} seconds)")
+print(f"Right context: {context_length} frames ({context_length / sample_rate} seconds)")
 ######################################################################
 # 4. Configure the audio stream
@@ -109,7 +109,7 @@ print(f"Right context: {right_context_size} frames ({right_context_size / sample
 src = "https://download.pytorch.org/torchaudio/tutorial-assets/greatpiratestories_00_various.mp3"
 streamer = Streamer(src)
-streamer.add_basic_audio_stream(frames_per_chunk=frames_per_chunk, sample_rate=bundle.sample_rate)
+streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=bundle.sample_rate)
 print(streamer.get_src_stream_info(0))
 print(streamer.get_out_stream_info(0))
@@ -125,18 +125,21 @@ class ContextCacher:
     """Cache the previous chunk and combine it with the new chunk
     Args:
-        chunk (torch.Tensor): Initial chunk
-        right_context_size (int): The size of right context.
+        segment_length (int): The size of main segment.
+            If the incoming segment is shorter, then the segment is padded.
+        context_length (int): The size of the context, cached and appended.
     """
-    def __init__(self, chunk: torch.Tensor, right_context_size: int):
-        self.chunk = chunk
-        self.right_context_size = right_context_size
+    def __init__(self, segment_length: int, context_length: int):
+        self.segment_length = segment_length
+        self.context_length = context_length
+        self.context = torch.zeros([context_length])
     def __call__(self, chunk: torch.Tensor):
-        right_context = chunk[: self.right_context_size, :]
-        chunk_with_context = torch.cat((self.chunk, right_context))
-        self.chunk = chunk
+        if chunk.size(0) < self.segment_length:
+            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
+        chunk_with_context = torch.cat((self.context, chunk))
+        self.context = chunk[-self.context_length :]
         return chunk_with_context
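
A quick sanity check of the updated cacher, outside the diff (the lengths are made-up values and `torch` plus the new `ContextCacher` are assumed to be in scope):

cacher = ContextCacher(segment_length=640, context_length=320)

full = torch.randn(640)
out1 = cacher(full)
print(out1.shape)                                  # torch.Size([960]): zero context + chunk
print(torch.equal(out1[:320], torch.zeros(320)))   # True: the context starts as zeros

short = torch.randn(500)                           # e.g. a truncated final chunk
out2 = cacher(short)
print(out2.shape)                                  # torch.Size([960]): padded to segment_length
print(torch.equal(out2[:320], full[-320:]))        # True: context cached from the previous chunk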
@@ -151,8 +154,7 @@ class ContextCacher:
 # decoding state between inference calls.
 #
-stream_iterator = streamer.stream()
-cacher = ContextCacher(next(stream_iterator)[0], right_context_size)
+cacher = ContextCacher(segment_length, context_length)
 state, hypothesis = None, None
@@ -169,8 +171,8 @@ state, hypothesis = None, None
 def run_inference(num_iter=200):
     global state, hypothesis
     chunks = []
-    for i, (chunk,) in enumerate(stream_iterator, start=1):
-        segment = cacher(chunk).T[0]
+    for i, (chunk,) in enumerate(streamer.stream(), start=1):
+        segment = cacher(chunk[:, 0])
         features, length = feature_extractor(segment)
         hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
         hypothesis = hypos[0]
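
With this change the loop is driven as before; a minimal invocation for reference (num_iter defaults to 200 per the definition above, and the module-level state and hypothesis carry the decoding session across calls):

run_inference()   # decodes up to 200 chunks; a later call continues from the carried-over state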