Commit 55575a53 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add mel spectrogram visualization to Streaming ASR tutorial (#2974)

Summary:
Per the suggestion by nateanl, this adds a visualization of the mel-spectrogram features fed to the ASR model.

<img width="688" alt="Screen Shot 2023-01-12 at 8 19 59 PM" src="https://user-images.githubusercontent.com/855818/212215190-23be7553-4c04-40d9-944e-3ee2ff69c49b.png">

Pull Request resolved: https://github.com/pytorch/audio/pull/2974

Reviewed By: nateanl

Differential Revision: D42484088

Pulled By: mthrok

fbshipit-source-id: 2c839492869416554eac04aa06cd12078db21bd7
parent a5664ca9
......@@ -48,6 +48,7 @@ print(torchaudio.__version__)
######################################################################
#
import IPython
import matplotlib.pyplot as plt
try:
from torchaudio.io import StreamReader
......@@ -195,10 +196,28 @@ state, hypothesis = None, None
stream_iterator = streamer.stream()
def _plot(feats, num_iter, unit=25):
unit_dur = segment_length / sample_rate * unit
num_plots = num_iter // unit + (1 if num_iter % unit else 0)
fig, axes = plt.subplots(num_plots, 1)
t0 = 0
for i, ax in enumerate(axes):
feats_ = feats[i*unit:(i+1)*unit]
t1 = t0 + segment_length / sample_rate * len(feats_)
feats_ = torch.cat([f[2:-2] for f in feats_]) # remove boundary effect and overlap
ax.imshow(feats_.T, extent=[t0, t1, 0, 1], aspect="auto", origin="lower")
ax.tick_params(which='both', left=False, labelleft=False)
ax.set_xlim(t0, t0 + unit_dur)
t0 = t1
fig.suptitle("MelSpectrogram Feature")
plt.tight_layout()
@torch.inference_mode()
def run_inference(num_iter=200):
def run_inference(num_iter=100):
global state, hypothesis
chunks = []
feats = []
for i, (chunk,) in enumerate(stream_iterator, start=1):
segment = cacher(chunk[:, 0])
features, length = feature_extractor(segment)
......@@ -208,9 +227,12 @@ def run_inference(num_iter=200):
print(transcript, end="", flush=True)
chunks.append(chunk)
feats.append(features)
if i == num_iter:
break
# Plot the features
_plot(feats, num_iter)
return IPython.display.Audio(torch.cat(chunks).T.numpy(), rate=bundle.sample_rate)
......@@ -249,6 +271,36 @@ run_inference()
run_inference()
######################################################################
#
run_inference()
######################################################################
#
run_inference()
######################################################################
#
run_inference()
######################################################################
#
run_inference()
######################################################################
#
run_inference()
######################################################################
#
run_inference()
######################################################################
#
# Tag: :obj:`torchaudio.io`
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment