"src/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "a7c5007c238830238f68aa88bc37cc5e424fa82b"
Commit 8e3c6144 authored by moto, committed by Facebook GitHub Bot

Update context building to not delay the inference (#2213)

Summary:
Updating the context cacher so that the fetched audio chunk is used for inference immediately.

https://github.com/pytorch/audio/pull/2202#discussion_r802838174

Pull Request resolved: https://github.com/pytorch/audio/pull/2213

Reviewed By: hwangjeff

Differential Revision: D34235230

Pulled By: mthrok

fbshipit-source-id: 6e4aee7cca34ca81e40c0cb13497182f20f7f04e
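
Why this removes the delay: the old cacher emitted the previous chunk plus a right context sliced off the newly fetched chunk, so the audio fetched on a given iteration was only decoded on the next one. The new cacher keeps a left context and prepends it to the current chunk instead. A minimal sketch of the difference with toy 1-D tensors (the OldCacher/NewCacher names are mine, and the old class actually sliced 2-D chunks):

import torch


class OldCacher:
    """Previous behavior: emit the *previous* chunk plus right context
    sliced from the new chunk, so each fetched chunk waits one call."""

    def __init__(self, chunk, right_context_size):
        self.chunk = chunk
        self.right_context_size = right_context_size

    def __call__(self, chunk):
        right_context = chunk[: self.right_context_size]
        out = torch.cat((self.chunk, right_context))
        self.chunk = chunk
        return out


class NewCacher:
    """New behavior: emit the cached left context plus the *current*
    chunk, so the freshly fetched audio is decoded immediately."""

    def __init__(self, segment_length, context_length):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = torch.zeros([context_length])

    def __call__(self, chunk):
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        out = torch.cat((self.context, chunk))
        self.context = chunk[-self.context_length:]
        return out


a, b = torch.ones(4), 2 * torch.ones(4)

old = OldCacher(a, right_context_size=2)
print(old(b))  # tensor([1., 1., 1., 1., 2., 2.]) -- b itself is not decoded yet

new = NewCacher(segment_length=4, context_length=2)
print(new(b))  # tensor([0., 0., 2., 2., 2., 2.]) -- b is decoded right away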
@@ -81,12 +81,12 @@ token_processor = bundle.get_token_processor()
 #
 sample_rate = bundle.sample_rate
-frames_per_chunk = bundle.segment_length * bundle.hop_length
-right_context_size = bundle.right_context_length * bundle.hop_length
+segment_length = bundle.segment_length * bundle.hop_length
+context_length = bundle.right_context_length * bundle.hop_length

 print(f"Sample rate: {sample_rate}")
-print(f"Main segment: {frames_per_chunk} frames ({frames_per_chunk / sample_rate} seconds)")
-print(f"Right context: {right_context_size} frames ({right_context_size / sample_rate} seconds)")
+print(f"Main segment: {segment_length} frames ({segment_length / sample_rate} seconds)")
+print(f"Right context: {context_length} frames ({context_length / sample_rate} seconds)")

 ######################################################################
 # 4. Configure the audio stream
@@ -109,7 +109,7 @@ print(f"Right context: {right_context_size} frames ({right_context_size / sample
 src = "https://download.pytorch.org/torchaudio/tutorial-assets/greatpiratestories_00_various.mp3"

 streamer = Streamer(src)
-streamer.add_basic_audio_stream(frames_per_chunk=frames_per_chunk, sample_rate=bundle.sample_rate)
+streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=bundle.sample_rate)

 print(streamer.get_src_stream_info(0))
 print(streamer.get_out_stream_info(0))
@@ -125,18 +125,21 @@ class ContextCacher:
     """Cache the previous chunk and combine it with the new chunk

     Args:
-        chunk (torch.Tensor): Initial chunk
-        right_context_size (int): The size of right context.
+        segment_length (int): The size of main segment.
+            If the incoming segment is shorter, then the segment is padded.
+        context_length (int): The size of the context, cached and appended.
     """

-    def __init__(self, chunk: torch.Tensor, right_context_size: int):
-        self.chunk = chunk
-        self.right_context_size = right_context_size
+    def __init__(self, segment_length: int, context_length: int):
+        self.segment_length = segment_length
+        self.context_length = context_length
+        self.context = torch.zeros([context_length])

     def __call__(self, chunk: torch.Tensor):
-        right_context = chunk[: self.right_context_size, :]
-        chunk_with_context = torch.cat((self.chunk, right_context))
-        self.chunk = chunk
+        if chunk.size(0) < self.segment_length:
+            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
+        chunk_with_context = torch.cat((self.context, chunk))
+        self.context = chunk[-self.context_length :]
         return chunk_with_context
@@ -151,8 +154,7 @@ class ContextCacher:
 # decoding state between inference calls.
 #
-stream_iterator = streamer.stream()
-cacher = ContextCacher(next(stream_iterator)[0], right_context_size)
+cacher = ContextCacher(segment_length, context_length)

 state, hypothesis = None, None
@@ -169,8 +171,8 @@ state, hypothesis = None, None
 def run_inference(num_iter=200):
     global state, hypothesis
     chunks = []
-    for i, (chunk,) in enumerate(stream_iterator, start=1):
-        segment = cacher(chunk).T[0]
+    for i, (chunk,) in enumerate(streamer.stream(), start=1):
+        segment = cacher(chunk[:, 0])
         features, length = feature_extractor(segment)
         hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
         hypothesis = hypos[0]
...
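
For reference, a quick standalone check of the new padding behavior (a sketch, not part of the commit; the class body is copied from the hunk above) showing that a short final chunk is zero-padded to segment_length before the cached context is prepended:

import torch


class ContextCacher:
    # Copied from the new version in the diff above.
    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        chunk_with_context = torch.cat((self.context, chunk))
        self.context = chunk[-self.context_length:]
        return chunk_with_context


cacher = ContextCacher(segment_length=4, context_length=2)
print(cacher(torch.arange(1.0, 5.0)))  # tensor([0., 0., 1., 2., 3., 4.])
print(cacher(torch.arange(5.0, 7.0)))  # tensor([3., 4., 5., 6., 0., 0.]) -- short tail padded

Note that the inference loop now passes chunk[:, 0] into the cacher, so the cacher sees a 1-D mono waveform up front instead of transposing the concatenated result afterwards.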