Unverified Commit 0ac196c6 authored by nateanl, committed by GitHub

Fix style checks in examples/tutorials (#2006)

parent e6dd7fd3
@@ -22,16 +22,17 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import math
import os
@@ -46,17 +47,18 @@ _SAMPLE_DIR = "_assets"
SAMPLE_WAV_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/steam-train-whistle-daniel_simon.wav"
SAMPLE_WAV_PATH = os.path.join(_SAMPLE_DIR, "steam.wav")
SAMPLE_RIR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav" # noqa: E501
SAMPLE_RIR_PATH = os.path.join(_SAMPLE_DIR, "rir.wav")
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav" # noqa: E501
SAMPLE_NOISE_PATH = os.path.join(_SAMPLE_DIR, "bg.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data():
uri = [
(SAMPLE_WAV_URL, SAMPLE_WAV_PATH),
@@ -65,28 +67,33 @@ def _fetch_data():
(SAMPLE_NOISE_URL, SAMPLE_NOISE_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
def _get_sample(path, resample=None):
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f'{resample}'],
])
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_PATH, resample=resample)
def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy()
@@ -100,7 +107,7 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None)
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
if ylim:
@@ -108,6 +115,7 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None)
figure.suptitle(title)
plt.show(block=False)
def print_stats(waveform, sample_rate=None, src=None):
if src:
print("-" * 10)
@@ -125,11 +133,11 @@ def print_stats(waveform, sample_rate=None, src=None):
print(waveform)
print()
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
@@ -137,12 +145,13 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
@@ -154,15 +163,17 @@ def play_audio(waveform, sample_rate):
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def get_rir_sample(*, resample=None, processed=False):
rir_raw, sample_rate = _get_sample(SAMPLE_RIR_PATH, resample=resample)
if not processed:
return rir_raw, sample_rate
rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1])
return rir, sample_rate
def get_noise_sample(*, resample=None):
return _get_sample(SAMPLE_NOISE_PATH, resample=resample)
@@ -218,10 +229,11 @@ effects = [
# Apply effects
waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(
waveform1, sample_rate1, effects
)
plot_waveform(waveform1, sample_rate1, title="Original", xlim=(-0.1, 3.2))
plot_waveform(waveform2, sample_rate2, title="Effects Applied", xlim=(-0.1, 3.2))
print_stats(waveform1, sample_rate=sample_rate1, src="Original")
print_stats(waveform2, sample_rate=sample_rate2, src="Effects Applied")
@@ -268,7 +280,7 @@ play_audio(rir_raw, sample_rate)
# the signal power, then flip along the time axis.
#
rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1])
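#
# (Added note: the flip is needed because ``torch.nn.functional.conv1d``
# computes cross-correlation rather than convolution; reversing the RIR along
# the time axis makes the correlation below behave as a true convolution.)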
@@ -281,7 +293,7 @@ plot_waveform(rir, sample_rate, title="Room Impulse Response", ylim=None)
speech, _ = get_speech_sample(resample=sample_rate)
speech_ = torch.nn.functional.pad(speech, (rir.shape[1] - 1, 0))
augmented = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
plot_waveform(speech, sample_rate, title="Original", ylim=None)
@@ -312,7 +324,7 @@ play_audio(augmented, sample_rate)
sample_rate = 8000
speech, _ = get_speech_sample(resample=sample_rate)
noise, _ = get_noise_sample(resample=sample_rate)
noise = noise[:, : speech.shape[1]]
plot_waveform(noise, sample_rate, title="Background noise")
plot_specgram(noise, sample_rate, title="Background noise")
@@ -346,7 +358,7 @@ plot_specgram(waveform, sample_rate, title="Original")
play_audio(waveform, sample_rate)
configs = [
({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "wav", "encoding": "ULAW", "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "gsm"}, "GSM-FR"),
({"format": "mp3", "compression": -9}, "MP3"),
({"format": "vorbis", "compression": -1}, "Vorbis"),
@@ -373,7 +385,7 @@ play_audio(speech, sample_rate)
# Apply RIR
rir, _ = get_rir_sample(resample=sample_rate, processed=True)
speech_ = torch.nn.functional.pad(speech, (rir.shape[1] - 1, 0))
speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
plot_specgram(speech, sample_rate, title="RIR Applied")
@@ -384,7 +396,7 @@ play_audio(speech, sample_rate)
# the noise contains the acoustic feature of the environment. Therefore, we add
# the noise after RIR application.
noise, _ = get_noise_sample(resample=sample_rate)
noise = noise[:, : speech.shape[1]]
snr_db = 8
scale = math.exp(snr_db / 10) * noise.norm(p=2) / speech.norm(p=2)
@@ -399,7 +411,14 @@ speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
sample_rate,
effects=[
["lowpass", "4000"],
["compand", "0.02,0.05", "-60,-60,-30,-10,-20,-8,-5,-8,-2,-8", "-8", "-7", "0.05"],
[
"compand",
"0.02,0.05",
"-60,-60,-30,-10,-20,-8,-5,-8,-2,-8",
"-8",
"-7",
"0.05",
],
["rate", "8000"],
],
)
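
# (Added annotation, hedged from the sox manual: the ``compand`` arguments
# above are, in order, the attack,decay times; the transfer function as
# in-dB,out-dB points; the output gain in dB; the initial volume in dB; and
# the delay in seconds.)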
@@ -23,14 +23,14 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import multiprocessing
import os
@@ -42,19 +42,21 @@ _SAMPLE_DIR = "_assets"
YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def _download_yesno():
if os.path.exists(os.path.join(YESNO_DATASET_PATH, "waves_yesno.tar.gz")):
return
torchaudio.datasets.YESNO(root=YESNO_DATASET_PATH, download=True)
YESNO_DOWNLOAD_PROCESS = multiprocessing.Process(target=_download_yesno)
YESNO_DOWNLOAD_PROCESS.start()
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
@@ -62,12 +64,13 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
@@ -79,10 +82,12 @@ def play_audio(waveform, sample_rate):
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
######################################################################
# Here, we show how to use the ``YESNO`` dataset.
#
YESNO_DOWNLOAD_PROCESS.join()
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
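
######################################################################
# An illustrative access pattern (added here, not part of the original
# commit): each item is a ``(waveform, sample_rate, labels)`` tuple, where
# ``labels`` is a list of eight 0/1 values (0 for "no", 1 for "yes").
#

waveform, sample_rate, labels = dataset[0]
print(labels)
plot_specgram(waveform, sample_rate, title=f"Sample 0: {labels}")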
@@ -20,16 +20,17 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os
import requests
@@ -40,40 +41,45 @@ import matplotlib.pyplot as plt
_SAMPLE_DIR = "_assets"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data():
uri = [
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
def _get_sample(path, resample=None):
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f'{resample}'],
])
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def get_spectrogram(
n_fft=400,
win_len=None,
hop_len=None,
power=2.0,
):
waveform, _ = get_speech_sample()
spectrogram = T.Spectrogram(
@@ -86,17 +92,19 @@ )
)
return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
######################################################################
# SpecAugment
# -----------
@@ -111,18 +119,23 @@ def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=No
# ~~~~~~~~~~~
#
spec = get_spectrogram(power=None)
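# (Added note: ``power=None`` keeps the spectrogram complex-valued, which
# ``T.TimeStretch`` requires in order to manipulate the phase.)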
stretch = T.TimeStretch()
rate = 1.2
spec_ = stretch(spec, rate)
plot_spectrogram(
torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(
torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
)
######################################################################
# TimeMasking
@@ -38,16 +38,17 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os
import requests
@@ -59,7 +60,7 @@ from IPython.display import Audio, display
_SAMPLE_DIR = "_assets"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True)
@@ -70,25 +71,29 @@ def _fetch_data():
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
def _get_sample(path, resample=None):
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f'{resample}'],
])
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def print_stats(waveform, sample_rate=None, src=None):
if src:
print("-" * 10)
@@ -106,17 +111,19 @@ def print_stats(waveform, sample_rate=None, src=None):
print(waveform)
print()
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy()
@@ -130,7 +137,7 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None)
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
if ylim:
@@ -138,6 +145,7 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
@@ -149,14 +157,16 @@ def play_audio(waveform, sample_rate):
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def plot_mel_fbank(fbank, title=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Filter bank")
axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin")
plt.show(block=False)
def plot_pitch(waveform, sample_rate, pitch):
figure, axis = plt.subplots(1, 1)
axis.set_title("Pitch Feature")
@@ -164,16 +174,16 @@ def plot_pitch(waveform, sample_rate, pitch):
end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, pitch.shape[1])
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0)
plt.show(block=False)
def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
figure, axis = plt.subplots(1, 1)
axis.set_title("Kaldi Pitch Feature")
@@ -181,22 +191,24 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
time_axis = torch.linspace(0, end_time, pitch.shape[1])
ln1 = axis.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis.set_ylim((-1.3, 1.3))
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, nfcc.shape[1])
ln2 = axis2.plot(
time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--"
)
lns = ln1 + ln2
labels = [l.get_label() for l in lns]
axis.legend(lns, labels, loc=0)
plt.show(block=False)
######################################################################
# Spectrogram
# -----------
@@ -206,7 +218,6 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
#
waveform, sample_rate = get_speech_sample()
n_fft = 1024
@@ -226,7 +237,7 @@ spectrogram = T.Spectrogram(
spec = spectrogram(waveform)
print_stats(spec)
plot_spectrogram(spec[0], title="torchaudio")
######################################################################
# GriffinLim
@@ -280,10 +291,10 @@ sample_rate = 6000
mel_filters = F.melscale_fbanks(
int(n_fft // 2 + 1),
n_mels=n_mels,
f_min=0.0,
f_max=sample_rate / 2.0,
sample_rate=sample_rate,
norm="slaney",
)
plot_mel_fbank(mel_filters, "Mel Filter Bank - torchaudio")
@@ -300,16 +311,16 @@ mel_filters_librosa = librosa.filters.mel(
sample_rate,
n_fft,
n_mels=n_mels,
fmin=0.0,
fmax=sample_rate / 2.0,
norm="slaney",
htk=True,
).T
plot_mel_fbank(mel_filters_librosa, "Mel Filter Bank - librosa")
mse = torch.square(mel_filters - mel_filters_librosa).mean().item()
print("Mean Square Difference: ", mse)
######################################################################
# MelSpectrogram
@@ -336,15 +347,14 @@ mel_spectrogram = T.MelSpectrogram(
center=True,
pad_mode="reflect",
power=2.0,
norm="slaney",
onesided=True,
n_mels=n_mels,
mel_scale="htk",
)
melspec = mel_spectrogram(waveform)
plot_spectrogram(melspec[0], title="MelSpectrogram - torchaudio", ylabel="mel freq")
######################################################################
# Comparison against librosa
@@ -365,14 +375,13 @@ melspec_librosa = librosa.feature.melspectrogram(
pad_mode="reflect",
power=2.0,
n_mels=n_mels,
norm="slaney",
htk=True,
)
plot_spectrogram(melspec_librosa, title="MelSpectrogram - librosa", ylabel="mel freq")
mse = torch.square(melspec - melspec_librosa).mean().item()
print("Mean Square Difference: ", mse)
######################################################################
# MFCC
@@ -391,11 +400,11 @@ mfcc_transform = T.MFCC(
sample_rate=sample_rate,
n_mfcc=n_mfcc,
melkwargs={
"n_fft": n_fft,
"n_mels": n_mels,
"hop_length": hop_length,
"mel_scale": "htk",
},
)
mfcc = mfcc_transform(waveform)
@@ -409,18 +418,27 @@ plot_spectrogram(mfcc[0])
melspec = librosa.feature.melspectrogram(
y=waveform.numpy()[0],
sr=sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mels,
htk=True,
norm=None,
)
mfcc_librosa = librosa.feature.mfcc(
S=librosa.core.spectrum.power_to_db(melspec),
n_mfcc=n_mfcc,
dct_type=2,
norm="ortho",
)
plot_spectrogram(mfcc_librosa)
mse = torch.square(mfcc - mfcc_librosa).mean().item()
print("Mean Square Difference: ", mse)
######################################################################
# Pitch
@@ -21,12 +21,13 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
import io
@@ -51,7 +52,7 @@ SAMPLE_MP3_PATH = os.path.join(_SAMPLE_DIR, "steam.mp3")
SAMPLE_GSM_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/steam-train-whistle-daniel_simon.gsm"
SAMPLE_GSM_PATH = os.path.join(_SAMPLE_DIR, "steam.gsm")
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
SAMPLE_TAR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit.tar.gz"
@@ -61,6 +62,7 @@ SAMPLE_TAR_ITEM = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp
S3_BUCKET = "pytorch-tutorial-assets"
S3_KEY = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
def _fetch_data():
os.makedirs(_SAMPLE_DIR, exist_ok=True)
uri = [
@@ -71,11 +73,13 @@ def _fetch_data():
(SAMPLE_TAR_URL, SAMPLE_TAR_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
def print_stats(waveform, sample_rate=None, src=None):
if src:
print("-" * 10)
@@ -93,6 +97,7 @@ def print_stats(waveform, sample_rate=None, src=None):
print(waveform)
print()
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy()
@@ -106,7 +111,7 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None)
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
if ylim:
@@ -114,11 +119,11 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None)
figure.suptitle(title)
plt.show(block=False)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
@@ -126,12 +131,13 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
@@ -143,20 +149,23 @@ def play_audio(waveform, sample_rate):
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def _get_sample(path, resample=None):
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f'{resample}'],
])
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_PATH, resample=resample)
def inspect_file(path):
print("-" * 10)
print("Source:", path)
@@ -164,14 +173,16 @@ def inspect_file(path):
print(f" - File size: {os.path.getsize(path)} bytes")
print(f" - {torchaudio.info(path)}")
######################################################################
# Querying audio metadata
# -----------------------
#
# Function ``torchaudio.info`` fetches audio metadata. You can provide
# a path-like object or file-like object.
#
metadata = torchaudio.info(SAMPLE_WAV_PATH)
print(metadata)
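
######################################################################
# A minimal sketch of the file-like variant (an illustration added here, not
# part of the original commit); ``format="wav"`` is passed explicitly as an
# assumption, since a raw byte stream carries no file extension to infer it
# from.
#

with open(SAMPLE_WAV_PATH, "rb") as fileobj:
    print(torchaudio.info(fileobj, format="wav"))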
@@ -295,15 +306,15 @@ with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
plot_specgram(waveform, sample_rate, title="HTTP datasource")
# Load audio from tar file
with tarfile.open(SAMPLE_TAR_PATH, mode="r") as tarfile_:
fileobj = tarfile_.extractfile(SAMPLE_TAR_ITEM)
waveform, sample_rate = torchaudio.load(fileobj)
plot_specgram(waveform, sample_rate, title="TAR file")
# Load audio from S3
client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
response = client.get_object(Bucket=S3_BUCKET, Key=S3_KEY)
waveform, sample_rate = torchaudio.load(response["Body"])
plot_specgram(waveform, sample_rate, title="From S3")
@@ -337,13 +348,14 @@ frame_offset, num_frames = 16000, 16000 # Fetch and decode the 1 - 2 seconds
print("Fetching all the data...")
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform1, sample_rate1 = torchaudio.load(response.raw)
waveform1 = waveform1[:, frame_offset: frame_offset + num_frames]
print(f" - Fetched {response.raw.tell()} bytes")
print("Fetching until the requested frames are available...")
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform2, sample_rate2 = torchaudio.load(
response.raw, frame_offset=frame_offset, num_frames=num_frames
)
print(f" - Fetched {response.raw.tell()} bytes")
print("Checking the resulting waveform ... ", end="")
@@ -389,9 +401,7 @@ inspect_file(path)
# Save as 16-bit signed integer Linear PCM
# The resulting file occupies half the storage but loses precision
path = f"{_SAMPLE_DIR}/save_example_PCM_S16.wav"
torchaudio.save(path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
inspect_file(path)
@@ -435,4 +445,3 @@ torchaudio.save(buffer_, waveform, sample_rate, format="wav")
buffer_.seek(0)
print(buffer_.read(16))
@@ -24,14 +24,14 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import math
import time
@@ -46,7 +46,7 @@ DEFAULT_OFFSET = 201
SWEEP_MAX_SAMPLE_RATE = 48000
DEFAULT_LOWPASS_FILTER_WIDTH = 6
DEFAULT_ROLLOFF = 0.99
DEFAULT_RESAMPLING_METHOD = "sinc_interpolation"
def _get_log_freq(sample_rate, max_sweep_rate, offset):
@@ -55,15 +55,18 @@ def _get_log_freq(sample_rate, max_sweep_rate, offset):
offset is used to avoid negative infinity `log(offset + x)`.
"""
half = sample_rate // 2
start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
return (
torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
)
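
# (Added note, stated loosely: the returned frequencies follow
# f(i) = exp(linspace(log(offset), log(offset + max_sweep_rate // 2), N)) - offset,
# so the sweep spends roughly equal time per octave rather than per hertz,
# and ``offset`` keeps the logarithm away from log(0).)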
def _get_inverse_log_freq(freq, sample_rate, offset):
"""Find the time where the given frequency is given by _get_log_freq"""
half = sample_rate // 2
return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset))
def _get_freq_ticks(sample_rate, offset, f_max):
# Given the original sample rate used for generating the sweep,
# find the x-axis value where the log-scale major frequency values fall in
@@ -80,6 +83,7 @@ def _get_freq_ticks(sample_rate, offset, f_max):
freq.append(f_max)
return time, freq
def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET):
max_sweep_rate = sample_rate
freq = _get_log_freq(sample_rate, max_sweep_rate, offset)
@@ -88,9 +92,16 @@ def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET):
signal = torch.sin(cummulative).unsqueeze(dim=0)
return signal
def plot_sweep(
waveform,
sample_rate,
title,
max_sweep_rate=SWEEP_MAX_SAMPLE_RATE,
offset=DEFAULT_OFFSET,
):
x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
y_ticks = [1000, 5000, 10000, 20000, sample_rate // 2]
time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq]
@@ -100,13 +111,14 @@ def plot_sweep(waveform, sample_rate, title, max_sweep_rate=SWEEP_MAX_SAMPLE_RAT
axis.specgram(waveform[0].numpy(), Fs=sample_rate)
plt.xticks(time, freq_x)
plt.yticks(freq_y, freq_y)
axis.set_xlabel("Original Signal Frequency (Hz, log scale)")
axis.set_ylabel("Waveform Frequency (Hz)")
axis.xaxis.grid(True, alpha=0.67)
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.show(block=True)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
@@ -118,11 +130,11 @@ def play_audio(waveform, sample_rate):
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
@@ -130,12 +142,13 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
def benchmark_resample(
method,
waveform,
@@ -146,18 +159,30 @@ resampling_method=DEFAULT_RESAMPLING_METHOD,
resampling_method=DEFAULT_RESAMPLING_METHOD,
beta=None,
librosa_type=None,
iters=5,
):
if method == "functional":
begin = time.time()
for _ in range(iters):
F.resample(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=lowpass_filter_width,
rolloff=rolloff,
resampling_method=resampling_method,
)
elapsed = time.time() - begin
return elapsed / iters
elif method == "transforms":
resampler = T.Resample(
sample_rate,
resample_rate,
lowpass_filter_width=lowpass_filter_width,
rolloff=rolloff,
resampling_method=resampling_method,
dtype=waveform.dtype,
)
begin = time.time()
for _ in range(iters):
resampler(waveform)
@@ -167,10 +192,13 @@ def benchmark_resample(
waveform_np = waveform.squeeze().numpy()
begin = time.time()
for _ in range(iters):
librosa.resample(
waveform_np, sample_rate, resample_rate, res_type=librosa_type
)
elapsed = time.time() - begin
return elapsed / iters
######################################################################
# To resample an audio waveform from one frequency to another, you can use
# ``transforms.Resample`` or ``functional.resample``.
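#
# A minimal usage sketch (an illustration added here, not part of the
# original commit):

sweep = get_sine_sweep(48000)
resampler = T.Resample(48000, 32000, dtype=sweep.dtype)
resampled = resampler(sweep)  # equivalent to F.resample(sweep, 48000, 32000)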
@@ -202,6 +230,7 @@ def benchmark_resample(
# plotted waveform, and color intensity the amplitude.
#
sample_rate = 48000
resample_rate = 32000
@@ -235,10 +264,14 @@ play_audio(waveform, sample_rate)
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, lowpass_filter_width=6
)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, lowpass_filter_width=128
)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")
@@ -282,10 +315,14 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation")
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation"
)
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window")
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, resampling_method="kaiser_window"
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
@@ -301,7 +338,7 @@ plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
sample_rate = 48000
resample_rate = 32000
# kaiser_best
resampled_waveform = F.resample(
waveform,
sample_rate,
@@ -309,18 +346,23 @@ resampled_waveform = F.resample(
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
librosa_resampled_waveform = torch.from_numpy(
librosa.resample(
waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_best"
)
).unsqueeze(0)
plot_sweep(
librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)"
)
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse)
# kaiser_fast
resampled_waveform = F.resample(
waveform,
sample_rate,
@@ -328,13 +370,20 @@ resampled_waveform = F.resample(
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
plot_specgram(
resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)"
)
plot_specgram(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")
librosa_resampled_waveform = torch.from_numpy(
librosa.resample(
waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_fast"
)
).unsqueeze(0)
plot_sweep(
librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)"
)
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse)
@@ -377,19 +426,29 @@ for label in configs:
waveform = get_sine_sweep(sample_rate)
# sinc 64 zero-crossings
f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
f_time = benchmark_resample(
"functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64
)
t_time = benchmark_resample(
"transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64
)
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append(f"sinc (width 64)")
rows.append("sinc (width 64)")
# sinc 6 zero-crossings
f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
f_time = benchmark_resample(
"functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16
)
t_time = benchmark_resample(
"transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16
)
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append(f"sinc (width 16)")
rows.append("sinc (width 16)")
# kaiser best
lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best")
lib_time = benchmark_resample(
"librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best"
)
f_time = benchmark_resample(
"functional",
waveform,
@@ -398,7 +457,8 @@ for label in configs:
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492,
)
t_time = benchmark_resample(
"transforms",
waveform,
@@ -407,12 +467,15 @@ for label in configs:
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492,
)
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
rows.append(f"kaiser_best")
rows.append("kaiser_best")
# kaiser fast
lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast")
lib_time = benchmark_resample(
"librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast"
)
f_time = benchmark_resample(
"functional",
waveform,
@@ -421,7 +484,8 @@ for label in configs:
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
t_time = benchmark_resample(
"transforms",
waveform,
@@ -430,12 +494,13 @@ for label in configs:
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
rows.append(f"kaiser_fast")
rows.append("kaiser_fast")
df = pd.DataFrame(
times, columns=["librosa", "functional", "transforms"], index=rows
)
df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns])
display(df.round(2))
@@ -47,21 +47,21 @@ import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = "https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_FILE = "_assets/speech.wav"
if not os.path.exists(SPEECH_FILE):
os.makedirs("_assets", exist_ok=True)
with open(SPEECH_FILE, "wb") as file:
file.write(requests.get(SPEECH_URL).content)
######################################################################
@@ -142,12 +142,13 @@ plt.show()
# [`distill.pub <https://distill.pub/2017/ctc/>`__])
#
transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in transcript]
print(list(zip(transcript, tokens)))
def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)
@@ -155,10 +156,10 @@ def get_trellis(emission, tokens, blank_id=0):
# Trellis has extra dimensions for both time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.full((num_frame + 1, num_tokens + 1), -float("inf"))
trellis[:, 0] = 0
for t in range(num_frame):
trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
@@ -166,12 +167,13 @@ def get_trellis(emission, tokens, blank_id=0):
)
return trellis
trellis = get_trellis(emission, tokens)
################################################################################
# Visualization
################################################################################
plt.imshow(trellis[1:, 1:].T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.colorbar()
plt.show()
@@ -203,6 +205,7 @@ plt.show()
# emission matrix.
#
@dataclass
class Point:
token_index: int
@@ -214,9 +217,9 @@ def backtrack(trellis, emission, tokens, blank_id=0):
# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()
@@ -227,14 +230,14 @@ def backtrack(trellis, emission, tokens, blank_id=0):
# Note (again):
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
# Score for token staying the same from time frame J-1 to T.
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for token changing from C-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
# 2. Store the path with frame-wise probability.
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j - 1, t - 1, prob))
# 3. Update the token
if changed > stayed:
@@ -242,21 +245,24 @@ def backtrack(trellis, emission, tokens, blank_id=0):
if j == 0:
break
else:
raise ValueError("Failed to align")
return path[::-1]
path = backtrack(trellis, emission, tokens)
print(path)
################################################################################
# Visualization
################################################################################
def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
for _, p in enumerate(path):
trellis_with_path[p.time_index, p.token_index] = float("nan")
plt.imshow(trellis_with_path[1:, 1:].T, origin="lower")
plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
@@ -270,6 +276,7 @@ plt.show()
# probability for the merged segments.
#
# Merge the labels
@dataclass
class Segment:
@@ -285,6 +292,7 @@ class Segment:
def length(self):
return self.end - self.start
def merge_repeats(path):
i1, i2 = 0, 0
segments = []
@@ -292,14 +300,23 @@ def merge_repeats(path):
while i2 < len(path) and path[i1].token_index == path[i2].token_index:
i2 += 1
score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(
Segment(
transcript[path[i1].token_index],
path[i1].time_index,
path[i2 - 1].time_index + 1,
score,
)
)
i1 = i2
return segments
segments = merge_repeats(path)
for seg in segments:
print(seg)
################################################################################
# Visualization
################################################################################
@@ -307,41 +324,42 @@ def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start + 1: seg.end + 1, i + 1] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin="lower")
ax1.set_xticks([])
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start + 0.7, i + 0.3), weight="bold")
ax1.annotate(f"{seg.score:.2f}", (seg.start - 0.3, i + 4.3))
ax2.set_title("Label probability with and without repetation")
xs, hs, ws = [], [], []
for seg in segments:
if seg.label != "|":
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
for p in path:
label = transcript[p.token_index]
if label != "|":
xs.append(p.time_index + 1)
hs.append(p.score)
ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim())
ax2.set_ylim(-0.1, 1.1)
plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
@@ -357,38 +375,44 @@ plt.show()
#
# Merge words
def merge_words(segments, separator="|"):
words = []
i1, i2 = 0, 0
while i1 < len(segments):
if i2 >= len(segments) or segments[i2].label == separator:
if i1 != i2:
segs = segments[i1:i2]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(
seg.length for seg in segs
)
words.append(
Segment(word, segments[i1].start, segments[i2 - 1].end, score)
)
i1 = i2 + 1
i2 = i1
else:
i2 += 1
return words
word_segments = merge_words(segments)
for word in word_segments:
print(word)
################################################################################
# Visualization
################################################################################
def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start + 1: seg.end + 1, i + 1] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.imshow(trellis_with_path[1:, 1:].T, origin="lower")
ax1.set_xticks([])
ax1.set_yticks([])
@@ -397,9 +421,9 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax1.axvline(word.end - 0.5)
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i + 0.3))
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 4), fontsize=8)
# The original waveform
ratio = waveform.size(0) / (trellis.size(0) - 1)
@@ -407,22 +431,29 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for word in word_segments:
x0 = ratio * word.start
x1 = ratio * word.end
ax2.axvspan(x0, x1, alpha=0.1, color="red")
ax2.annotate(f"{word.score:.2f}", (x0, 0.8))
for seg in segments:
if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9))
xticks = ax2.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate)
ax2.set_xlabel("time [second]")
ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1))
plot_alignments(
trellis,
segments,
word_segments,
waveform[0],
)
plt.show()
# A trick to embed the resulting audio in the generated file.
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
@@ -433,9 +464,12 @@ def display_segment(i):
x1 = int(ratio * word.end)
filename = f"_assets/{i}_{word.label}.wav"
torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
print(
f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec"
)
return IPython.display.Audio(filename)
######################################################################
#
@@ -40,7 +40,7 @@ MVDR with torchaudio
# which was generated with:
#
# - ``SSB07200001.wav`` from `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License v.2.0)
# - ``noise-sound-bible-0038.wav`` from `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0 International — CC BY 4.0) # noqa: E501
#
import os
@@ -50,24 +50,24 @@ import torchaudio
import IPython.display as ipd
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
print(torchaudio.__version__)
print(device)
filenames = [
'mix.wav',
'reverb_clean.wav',
'clean.wav',
"mix.wav",
"reverb_clean.wav",
"clean.wav",
]
base_url = 'https://download.pytorch.org/torchaudio/tutorial-assets/mvdr'
base_url = "https://download.pytorch.org/torchaudio/tutorial-assets/mvdr"
for filename in filenames:
os.makedirs('_assets', exist_ok=True)
os.makedirs("_assets", exist_ok=True)
if not os.path.exists(f"_assets/{filename}"):
with open(f'_assets/{filename}', 'wb') as file:
file.write(requests.get(f'{base_url}/{filename}').content)
with open(f"_assets/{filename}", "wb") as file:
file.write(requests.get(f"{base_url}/{filename}").content)
######################################################################
# Generate the Ideal Ratio Mask (IRM)
@@ -79,9 +79,9 @@ for filename in filenames:
# ~~~~~~~~~~~~~~~~~~
#
mix, sr = torchaudio.load('_assets/mix.wav')
reverb_clean, sr2 = torchaudio.load('_assets/reverb_clean.wav')
clean, sr3 = torchaudio.load('_assets/clean.wav')
mix, sr = torchaudio.load("_assets/mix.wav")
reverb_clean, sr2 = torchaudio.load("_assets/reverb_clean.wav")
clean, sr3 = torchaudio.load("_assets/clean.wav")
assert sr == sr2 == sr3
noise = mix - reverb_clean
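# Since the mixture was constructed as reverberant clean speech plus
# noise, subtracting the former recovers the noise component.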
@@ -125,8 +125,7 @@ spec_noise = stft(noise)
#
def get_irms(spec_clean, spec_noise, spec_mix):
mag_mix = spec_mix.abs() ** 2
def get_irms(spec_clean, spec_noise):
mag_clean = spec_clean.abs() ** 2
mag_noise = spec_noise.abs() ** 2
irm_speech = mag_clean / (mag_clean + mag_noise)
@@ -134,12 +133,13 @@ def get_irms(spec_clean, spec_noise, spec_mix):
return irm_speech, irm_noise
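# Written out, the masks computed above follow the usual power-ratio
# definition (a sketch, matching the visible code):
#
#     IRM_speech(t, f) = |S(t, f)|**2 / (|S(t, f)|**2 + |N(t, f)|**2)
#     IRM_noise(t, f)  = |N(t, f)|**2 / (|S(t, f)|**2 + |N(t, f)|**2)
#
# where S and N are the STFTs of the (reverberant) clean speech and noise.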
######################################################################
# .. note::
# We use the reverberant clean speech as the target here;
# you can also set it to the dry clean speech.
irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise, spec_mix)
irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise)
######################################################################
# Apply MVDR
@@ -152,7 +152,7 @@ irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise, spec_mix)
#
results_multi = {}
for solution in ['ref_channel', 'stv_evd', 'stv_power']:
for solution in ["ref_channel", "stv_evd", "stv_power"]:
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True)
stft_est = mvdr(spec_mix, irm_speech, irm_noise)
est = istft(stft_est, length=mix.shape[-1])
@@ -166,8 +166,10 @@ for solution in ['ref_channel', 'stv_evd', 'stv_power']:
# The channel selection may depend on the design of the microphone array
results_single = {}
for solution in ['ref_channel', 'stv_evd', 'stv_power']:
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)
for solution in ["ref_channel", "stv_evd", "stv_power"]:
mvdr = torchaudio.transforms.MVDR(
ref_channel=0, solution=solution, multi_mask=False
)
stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])
est = istft(stft_est, length=mix.shape[-1])
results_single[solution] = est
@@ -177,6 +179,7 @@ for solution in ['ref_channel', 'stv_evd', 'stv_power']:
# ~~~~~~~~~~~~~~~~~~~~~
#
def si_sdr(estimate, reference, epsilon=1e-8):
estimate = estimate - estimate.mean()
reference = reference - reference.mean()
@@ -196,6 +199,7 @@ def si_sdr(estimate, reference, epsilon=1e-8):
sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
return sisdr.item()
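# For reference, a self-contained sketch of the standard SI-SDR
# computation (the helper name and intermediate variable names are
# illustrative, not the tutorial's): project the estimate onto the
# reference to find the optimal scale, then compare scaled-reference
# power to error power in dB.
def si_sdr_sketch(estimate, reference, epsilon=1e-8):
    # Remove DC offsets so the scaling step is well defined.
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    # Optimal scaling of the reference (this is the "scale-invariant" part).
    reference_pow = reference.pow(2).mean()
    mix_pow = (estimate * reference).mean()
    scale = mix_pow / (reference_pow + epsilon)
    reference = scale * reference
    # Ratio of scaled-reference power to residual-error power, in dB.
    reference_pow = reference.pow(2).sum()
    error_pow = (estimate - reference).pow(2).sum()
    sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
    return sisdr.item()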
######################################################################
# Results
# -------
@@ -207,7 +211,9 @@ def si_sdr(estimate, reference, epsilon=1e-8):
#
for solution in results_single:
print(solution+": ", si_sdr(results_single[solution][None,...], reverb_clean[0:1]))
print(
solution + ": ", si_sdr(results_single[solution][None, ...], reverb_clean[0:1])
)
######################################################################
# Multi-channel mask results
@@ -215,7 +221,9 @@ for solution in results_single:
#
for solution in results_multi:
print(solution+": ", si_sdr(results_multi[solution][None,...], reverb_clean[0:1]))
print(
solution + ": ", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1])
)
######################################################################
# Original audio
@@ -253,39 +261,39 @@ ipd.Audio(clean[0], rate=16000)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_multi['ref_channel'], rate=16000)
ipd.Audio(results_multi["ref_channel"], rate=16000)
######################################################################
# Multi-channel mask, stv_evd solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_multi['stv_evd'], rate=16000)
ipd.Audio(results_multi["stv_evd"], rate=16000)
######################################################################
# Multi-channel mask, stv_power solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_multi['stv_power'], rate=16000)
ipd.Audio(results_multi["stv_power"], rate=16000)
######################################################################
# Single-channel mask, ref_channel solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_single['ref_channel'], rate=16000)
ipd.Audio(results_single["ref_channel"], rate=16000)
######################################################################
# Single-channel mask, stv_evd solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_single['stv_evd'], rate=16000)
ipd.Audio(results_single["stv_evd"], rate=16000)
######################################################################
# Single-channel mask, stv_power solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_single['stv_power'], rate=16000)
ipd.Audio(results_single["stv_power"], rate=16000)
@@ -48,21 +48,21 @@ import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SPEECH_FILE = "_assets/speech.wav"
if not os.path.exists(SPEECH_FILE):
os.makedirs('_assets', exist_ok=True)
with open(SPEECH_FILE, 'wb') as file:
os.makedirs("_assets", exist_ok=True)
with open(SPEECH_FILE, "wb") as file:
file.write(requests.get(SPEECH_URL).content)
@@ -241,6 +241,7 @@ print("Class labels:", bundle.get_labels())
# We start by defining the greedy decoding algorithm.
#
class GreedyCTCDecoder(torch.nn.Module):
def __init__(self, labels, blank=0):
super().__init__()
@@ -258,7 +259,7 @@ class GreedyCTCDecoder(torch.nn.Module):
indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank]
return ''.join([self.labels[i] for i in indices])
return "".join([self.labels[i] for i in indices])
######################################################################
@@ -63,7 +63,7 @@ import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -99,14 +99,16 @@ print(device)
# that are not in the table are ignored.
#
symbols = '_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)}
symbols = set(symbols)
def text_to_sequence(text):
text = text.lower()
return [look_up[s] for s in text if s in symbols]
text = "Hello world! Text to speech!"
print(text_to_sequence(text))
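# With the lookup table above, this should print the index of each
# character (derived by hand here, so worth re-checking):
#
#     [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 11, 31, 16, 35,
#      31, 11, 31, 26, 11, 30, 27, 16, 16, 14, 19, 2]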
@@ -136,7 +138,7 @@ print(lengths)
# The intermediate representation can be retrieved as follows.
#
print([processor.tokens[i] for i in processed[0, :lengths[0]]])
print([processor.tokens[i] for i in processed[0, : lengths[0]]])
######################################################################
@@ -179,7 +181,7 @@ print(lengths)
# The intermediate representation looks like the following.
#
print([processor.tokens[i] for i in processed[0, :lengths[0]]])
print([processor.tokens[i] for i in processed[0, : lengths[0]]])
######################################################################
@@ -269,7 +271,9 @@ fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())
torchaudio.save("_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
torchaudio.save(
"_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate
)
IPython.display.Audio("_assets/output_wavernn.wav")
@@ -298,7 +302,11 @@ fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())
torchaudio.save("_assets/output_griffinlim.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
torchaudio.save(
"_assets/output_griffinlim.wav",
waveforms[0:1].cpu(),
sample_rate=vocoder.sample_rate,
)
IPython.display.Audio("_assets/output_griffinlim.wav")
@@ -313,9 +321,20 @@ IPython.display.Audio("_assets/output_griffinlim.wav")
# Workaround to load a model checkpoint that was saved mapped to GPU
# https://stackoverflow.com/a/61840832
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp32', pretrained=False)
checkpoint = torch.hub.load_state_dict_from_url('https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth', progress=False, map_location=device)
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
waveglow = torch.hub.load(
"NVIDIA/DeepLearningExamples:torchhub",
"nvidia_waveglow",
model_math="fp32",
pretrained=False,
)
checkpoint = torch.hub.load_state_dict_from_url(
"https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth", # noqa: E501
progress=False,
map_location=device,
)
state_dict = {
key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()
}
waveglow.load_state_dict(state_dict)
waveglow = waveglow.remove_weightnorm(waveglow)
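# A hypothetical next step, using the ``infer`` method that NVIDIA's
# Torch Hub WaveGlow model exposes (the variable ``spec`` is assumed to
# hold the mel-spectrogram from the preceding Tacotron2 step):
#
#     waveglow = waveglow.to(device).eval()
#     with torch.no_grad():
#         waveforms = waveglow.infer(spec)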