Commit 84b12306 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Set and tweak global matplotlib configuration in tutorials (#3515)

Summary:
- Set global matplotlib rc params
- Fix style check
- Fix and updates FA tutorial plots
- Add av-asr index cards

Pull Request resolved: https://github.com/pytorch/audio/pull/3515

Reviewed By: huangruizhe

Differential Revision: D47894156

Pulled By: mthrok

fbshipit-source-id: b40d8d31f12ffc2b337e35e632afc216e9d59a6e
parent 8497ee91
......@@ -355,13 +355,14 @@ chunks = next(streamer.stream())
def _display(i):
    """Plot the waveform and spectrogram of output stream ``i`` and return an audio widget.

    Relies on tutorial module-level globals: ``streamer``, ``chunks``,
    ``sample_rate``, ``plt`` and ``IPython``.
    """
    print("filter_desc:", streamer.get_out_stream_info(i).filter_description)
    # Diff residue removed: the old "_, axs = plt.subplots(2, 1)" line is gone;
    # we keep the figure handle so fig.tight_layout() below can use it.
    fig, axs = plt.subplots(2, 1)
    waveform = chunks[i][:, 0]
    axs[0].plot(waveform)
    axs[0].grid(True)
    axs[0].set_ylim([-1, 1])
    # Hide the x tick labels of the top axes to reduce clutter.
    plt.setp(axs[0].get_xticklabels(), visible=False)
    axs[1].specgram(waveform, Fs=sample_rate)
    fig.tight_layout()
    return IPython.display.Audio(chunks[i].T, rate=sample_rate)
......@@ -440,7 +441,6 @@ def _display(i):
axs[j].imshow(chunk[10 * j + 1].permute(1, 2, 0))
axs[j].set_axis_off()
plt.tight_layout()
plt.show(block=False)
######################################################################
......
......@@ -592,7 +592,6 @@ for i, vid in enumerate(vids2):
if i == 0 and j == 0:
ax.set_ylabel("Stream 2")
plt.tight_layout()
plt.show(block=False)
######################################################################
#
......
......@@ -7,10 +7,6 @@ Text-to-Speech with Tacotron2
"""
import IPython
import matplotlib
import matplotlib.pyplot as plt
######################################################################
# Overview
# --------
......@@ -65,8 +61,6 @@ import matplotlib.pyplot as plt
import torch
import torchaudio
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
......@@ -75,6 +69,13 @@ print(torchaudio.__version__)
print(device)
######################################################################
#
import IPython
import matplotlib.pyplot as plt
######################################################################
# Text Processing
# ---------------
......@@ -226,13 +227,17 @@ _ = plt.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
# therefore, the process of generating the spectrogram incurs randomness.
#
fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3))
for i in range(3):
with torch.inference_mode():
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
print(spec[0].shape)
ax[i].imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
plt.show()
def plot():
    """Run Tacotron2 inference three times and show each generated spectrogram.

    Each run may produce a different spectrogram because the generation
    process incurs randomness (see the note above this cell in the tutorial).
    Uses module-level globals ``tacotron2``, ``processed``, ``lengths``.
    """
    _fig, axes = plt.subplots(3, 1)
    for row in range(3):
        with torch.inference_mode():
            spec, _spec_lengths, _ = tacotron2.infer(processed, lengths)
        print(spec[0].shape)
        axes[row].imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")


plot()
######################################################################
......@@ -270,11 +275,22 @@ with torch.inference_mode():
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
ax2.plot(waveforms[0].cpu().detach())
######################################################################
#
def plot(waveforms, spec, sample_rate):
    """Plot the synthesized waveform and its spectrogram; return an audio widget.

    Args:
        waveforms: Batch of generated waveform tensors; ``waveforms[0]`` is plotted
            and played back.
        spec: Spectrogram tensor; ``spec[0]`` is displayed.
        sample_rate: Playback rate passed to ``IPython.display.Audio``.

    Returns:
        ``IPython.display.Audio`` widget for the first waveform.
    """
    waveforms = waveforms.cpu().detach()
    # NOTE(review): removed a stray, discarded ``IPython.display.Audio(...)``
    # expression (diff residue) that referenced the global ``vocoder`` instead
    # of the ``sample_rate`` parameter; the return below is the real one.
    fig, [ax1, ax2] = plt.subplots(2, 1)
    ax1.plot(waveforms[0])
    ax1.set_xlim(0, waveforms.size(-1))
    ax1.grid(True)
    ax2.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
    return IPython.display.Audio(waveforms[0:1], rate=sample_rate)


plot(waveforms, spec, vocoder.sample_rate)
######################################################################
......@@ -300,11 +316,10 @@ with torch.inference_mode():
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
ax2.plot(waveforms[0].cpu().detach())
######################################################################
#
IPython.display.Audio(waveforms[0:1].cpu(), rate=vocoder.sample_rate)
plot(waveforms, spec, vocoder.sample_rate)
######################################################################
......@@ -339,8 +354,7 @@ waveglow.eval()
with torch.no_grad():
waveforms = waveglow.infer(spec)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
ax2.plot(waveforms[0].cpu().detach())
######################################################################
#
IPython.display.Audio(waveforms[0:1].cpu(), rate=22050)
plot(waveforms, spec, 22050)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment