Commit 55575a53 authored by moto, committed by Facebook GitHub Bot

Add mel spectrogram visualization to Streaming ASR tutorial (#2974)

Summary:
Per the suggestion by nateanl, this adds a visualization of the features fed to the ASR model.

<img width="688" alt="Screen Shot 2023-01-12 at 8 19 59 PM" src="https://user-images.githubusercontent.com/855818/212215190-23be7553-4c04-40d9-944e-3ee2ff69c49b.png">
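For reference, here is a minimal standalone sketch of the plotting approach: compute mel features chunk by chunk and render the concatenated result with `imshow`, as the new `_plot` helper does. This sketch uses `torchaudio.transforms.MelSpectrogram` as a stand-in for the tutorial's streaming feature extractor, and the `sample_rate` / `segment_length` values are placeholders, not the tutorial's actual configuration. Since the chunks here do not overlap, it also skips the boundary-frame trimming (`f[2:-2]`) that the tutorial's helper performs.

```python
import torch
import torchaudio
import matplotlib.pyplot as plt

sample_rate = 16000    # placeholder; the tutorial uses bundle.sample_rate
segment_length = 2560  # placeholder chunk size in samples

# Synthesize a test signal and split it into streaming-sized chunks.
t = torch.arange(sample_rate * 4) / sample_rate
waveform = torch.sin(2 * torch.pi * 440 * t)
chunks = waveform.split(segment_length)

# Compute mel features chunk by chunk, as a streaming pipeline would.
melspec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=80)
feats = [melspec(chunk).T for chunk in chunks]  # each is (time, n_mels)

# Concatenate along time and render, mirroring the tutorial's _plot helper.
stacked = torch.cat(feats)
fig, ax = plt.subplots()
ax.imshow(
    (stacked + 1e-10).log().T,  # log scale for visibility
    extent=[0, len(waveform) / sample_rate, 0, 1],
    aspect="auto",
    origin="lower",
)
ax.set_xlabel("Time (s)")
fig.suptitle("MelSpectrogram Feature")
plt.show()
```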

Pull Request resolved: https://github.com/pytorch/audio/pull/2974

Reviewed By: nateanl

Differential Revision: D42484088

Pulled By: mthrok

fbshipit-source-id: 2c839492869416554eac04aa06cd12078db21bd7
parent a5664ca9
```diff
@@ -48,6 +48,7 @@ print(torchaudio.__version__)
 ######################################################################
 #
 import IPython
+import matplotlib.pyplot as plt
 
 try:
     from torchaudio.io import StreamReader
@@ -195,10 +196,28 @@ state, hypothesis = None, None
 stream_iterator = streamer.stream()
 
 
+def _plot(feats, num_iter, unit=25):
+    unit_dur = segment_length / sample_rate * unit
+    num_plots = num_iter // unit + (1 if num_iter % unit else 0)
+    fig, axes = plt.subplots(num_plots, 1)
+    t0 = 0
+    for i, ax in enumerate(axes):
+        feats_ = feats[i*unit:(i+1)*unit]
+        t1 = t0 + segment_length / sample_rate * len(feats_)
+        feats_ = torch.cat([f[2:-2] for f in feats_])  # remove boundary effect and overlap
+        ax.imshow(feats_.T, extent=[t0, t1, 0, 1], aspect="auto", origin="lower")
+        ax.tick_params(which='both', left=False, labelleft=False)
+        ax.set_xlim(t0, t0 + unit_dur)
+        t0 = t1
+    fig.suptitle("MelSpectrogram Feature")
+    plt.tight_layout()
+
+
 @torch.inference_mode()
-def run_inference(num_iter=200):
+def run_inference(num_iter=100):
     global state, hypothesis
     chunks = []
+    feats = []
     for i, (chunk,) in enumerate(stream_iterator, start=1):
         segment = cacher(chunk[:, 0])
         features, length = feature_extractor(segment)
@@ -208,9 +227,12 @@ def run_inference(num_iter=200):
         print(transcript, end="", flush=True)
 
         chunks.append(chunk)
+        feats.append(features)
         if i == num_iter:
             break
 
+    # Plot the features
+    _plot(feats, num_iter)
     return IPython.display.Audio(torch.cat(chunks).T.numpy(), rate=bundle.sample_rate)
 
 
@@ -249,6 +271,36 @@ run_inference()
 run_inference()
 
 
+######################################################################
+#
+run_inference()
+
+
+######################################################################
+#
+run_inference()
+
+
+######################################################################
+#
+run_inference()
+
+
+######################################################################
+#
+run_inference()
+
+
+######################################################################
+#
+run_inference()
+
+
+######################################################################
+#
+run_inference()
+
 ######################################################################
 #
 # Tag: :obj:`torchaudio.io`
```