Unverified Commit 0ac196c6 authored by nateanl, committed by GitHub

Fix style checks in examples/tutorials (#2006)

parent e6dd7fd3
@@ -22,16 +22,17 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
#@title Prepare data and utility functions. {display-mode: "form"}
#@markdown
#@markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go.
#@markdown
#@markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commons BY 4.0.
#-------------------------------------------------------------------------------
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
import math
import os
@@ -46,125 +47,135 @@ _SAMPLE_DIR = "_assets"
SAMPLE_WAV_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/steam-train-whistle-daniel_simon.wav"
SAMPLE_WAV_PATH = os.path.join(_SAMPLE_DIR, "steam.wav")
SAMPLE_RIR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav"
SAMPLE_RIR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav" # noqa: E501
SAMPLE_RIR_PATH = os.path.join(_SAMPLE_DIR, "rir.wav")
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav" # noqa: E501
SAMPLE_NOISE_PATH = os.path.join(_SAMPLE_DIR, "bg.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data():
uri = [
(SAMPLE_WAV_URL, SAMPLE_WAV_PATH),
(SAMPLE_RIR_URL, SAMPLE_RIR_PATH),
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
(SAMPLE_NOISE_URL, SAMPLE_NOISE_PATH),
]
for url, path in uri:
with open(path, 'wb') as file_:
file_.write(requests.get(url).content)
uri = [
(SAMPLE_WAV_URL, SAMPLE_WAV_PATH),
(SAMPLE_RIR_URL, SAMPLE_RIR_PATH),
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
(SAMPLE_NOISE_URL, SAMPLE_NOISE_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
def _get_sample(path, resample=None):
effects = [
["remix", "1"]
]
if resample:
effects.extend([
["lowpass", f"{resample // 2}"],
["rate", f'{resample}'],
])
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_PATH, resample=resample)
return _get_sample(SAMPLE_WAV_PATH, resample=resample)
def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}')
if xlim:
axes[c].set_xlim(xlim)
if ylim:
axes[c].set_ylim(ylim)
figure.suptitle(title)
plt.show(block=False)
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
if ylim:
axes[c].set_ylim(ylim)
figure.suptitle(title)
plt.show(block=False)
def print_stats(waveform, sample_rate=None, src=None):
if src:
print("-" * 10)
print("Source:", src)
print("-" * 10)
if sample_rate:
print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}")
print()
print(waveform)
print()
if src:
print("-" * 10)
print("Source:", src)
print("-" * 10)
if sample_rate:
print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}")
print()
print(waveform)
print()
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}')
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def get_rir_sample(*, resample=None, processed=False):
rir_raw, sample_rate = _get_sample(SAMPLE_RIR_PATH, resample=resample)
if not processed:
return rir_raw, sample_rate
rir = rir_raw[:, int(sample_rate*1.01):int(sample_rate*1.3)]
rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1])
return rir, sample_rate
rir_raw, sample_rate = _get_sample(SAMPLE_RIR_PATH, resample=resample)
if not processed:
return rir_raw, sample_rate
rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1])
return rir, sample_rate
def get_noise_sample(*, resample=None):
return _get_sample(SAMPLE_NOISE_PATH, resample=resample)
return _get_sample(SAMPLE_NOISE_PATH, resample=resample)
######################################################################
@@ -208,20 +219,21 @@ waveform1, sample_rate1 = get_sample(resample=16000)
# Define effects
effects = [
["lowpass", "-1", "300"], # apply single-pole lowpass filter
["speed", "0.8"], # reduce the speed
# This only changes sample rate, so it is necessary to
# add `rate` effect with original sample rate after this.
["rate", f"{sample_rate1}"],
["reverb", "-w"], # Reverbration gives some dramatic feeling
["lowpass", "-1", "300"], # apply single-pole lowpass filter
["speed", "0.8"], # reduce the speed
# This only changes sample rate, so it is necessary to
# add `rate` effect with original sample rate after this.
["rate", f"{sample_rate1}"],
["reverb", "-w"], # Reverbration gives some dramatic feeling
]
# Apply effects
waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(
waveform1, sample_rate1, effects)
waveform1, sample_rate1, effects
)
plot_waveform(waveform1, sample_rate1, title="Original", xlim=(-.1, 3.2))
plot_waveform(waveform2, sample_rate2, title="Effects Applied", xlim=(-.1, 3.2))
plot_waveform(waveform1, sample_rate1, title="Original", xlim=(-0.1, 3.2))
plot_waveform(waveform2, sample_rate2, title="Effects Applied", xlim=(-0.1, 3.2))
print_stats(waveform1, sample_rate=sample_rate1, src="Original")
print_stats(waveform2, sample_rate=sample_rate2, src="Effects Applied")
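######################################################################
# A minimal sketch of the note above: per the comment in the effects list,
# ``speed`` alone only relabels the sample rate, so applying it without a
# trailing ``rate`` effect returns a sample rate that differs from the input
# (a hypothetical follow-up, reusing ``waveform1`` from above).

waveform3, sample_rate3 = torchaudio.sox_effects.apply_effects_tensor(
    waveform1, sample_rate1, [["speed", "0.8"]]
)
print(sample_rate1, sample_rate3)  # the rates differ until ``rate`` is re-applied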
@@ -268,7 +280,7 @@ play_audio(rir_raw, sample_rate)
# the signal power, then flip along the time axis.
#
rir = rir_raw[:, int(sample_rate*1.01):int(sample_rate*1.3)]
rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1])
@@ -281,7 +293,7 @@ plot_waveform(rir, sample_rate, title="Room Impulse Response", ylim=None)
speech, _ = get_speech_sample(resample=sample_rate)
speech_ = torch.nn.functional.pad(speech, (rir.shape[1]-1, 0))
speech_ = torch.nn.functional.pad(speech, (rir.shape[1] - 1, 0))
augmented = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
plot_waveform(speech, sample_rate, title="Original", ylim=None)
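######################################################################
# A quick numeric sketch (hypothetical 3-tap kernel) of why the RIR is
# flipped: ``torch.nn.functional.conv1d`` computes cross-correlation, so
# flipping the kernel along time turns it into a true convolution, and the
# left padding keeps the output length equal to the input length.

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
k = torch.tensor([[[0.5, 0.25, 0.125]]])
x_padded = torch.nn.functional.pad(x, (k.shape[-1] - 1, 0))
y = torch.nn.functional.conv1d(x_padded, torch.flip(k, [-1]))
print(y.shape)  # torch.Size([1, 1, 4]) - same length as x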
@@ -312,7 +324,7 @@ play_audio(augmented, sample_rate)
sample_rate = 8000
speech, _ = get_speech_sample(resample=sample_rate)
noise, _ = get_noise_sample(resample=sample_rate)
noise = noise[:, :speech.shape[1]]
noise = noise[:, : speech.shape[1]]
plot_waveform(noise, sample_rate, title="Background noise")
plot_specgram(noise, sample_rate, title="Background noise")
@@ -322,13 +334,13 @@ speech_power = speech.norm(p=2)
noise_power = noise.norm(p=2)
for snr_db in [20, 10, 3]:
snr = math.exp(snr_db / 10)
scale = snr * noise_power / speech_power
noisy_speech = (scale * speech + noise) / 2
snr = math.exp(snr_db / 10)
scale = snr * noise_power / speech_power
noisy_speech = (scale * speech + noise) / 2
plot_waveform(noisy_speech, sample_rate, title=f"SNR: {snr_db} [dB]")
plot_specgram(noisy_speech, sample_rate, title=f"SNR: {snr_db} [dB]")
play_audio(noisy_speech, sample_rate)
plot_waveform(noisy_speech, sample_rate, title=f"SNR: {snr_db} [dB]")
plot_specgram(noisy_speech, sample_rate, title=f"SNR: {snr_db} [dB]")
play_audio(noisy_speech, sample_rate)
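######################################################################
# For reference, a common alternative formulation, assuming an amplitude
# (L2-norm) definition of SNR in decibels:
#
#   snr_db = 20 * log10(||speech|| / ||noise||)
#
# which would scale the noise as in this hypothetical helper:


def scale_noise_to_snr(speech, noise, snr_db):
    snr = 10 ** (snr_db / 20)
    return noise * (speech.norm(p=2) / noise.norm(p=2)) / snr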
######################################################################
# Applying codec to Tensor object
@@ -346,15 +358,15 @@ plot_specgram(waveform, sample_rate, title="Original")
play_audio(waveform, sample_rate)
configs = [
({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "wav", "encoding": "ULAW", "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "gsm"}, "GSM-FR"),
({"format": "mp3", "compression": -9}, "MP3"),
({"format": "vorbis", "compression": -1}, "Vorbis"),
]
for param, title in configs:
augmented = F.apply_codec(waveform, sample_rate, **param)
plot_specgram(augmented, sample_rate, title=title)
play_audio(augmented, sample_rate)
augmented = F.apply_codec(waveform, sample_rate, **param)
plot_specgram(augmented, sample_rate, title=title)
play_audio(augmented, sample_rate)
######################################################################
# Simulating a phone recording
@@ -373,7 +385,7 @@ play_audio(speech, sample_rate)
# Apply RIR
rir, _ = get_rir_sample(resample=sample_rate, processed=True)
speech_ = torch.nn.functional.pad(speech, (rir.shape[1]-1, 0))
speech_ = torch.nn.functional.pad(speech, (rir.shape[1] - 1, 0))
speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
plot_specgram(speech, sample_rate, title="RIR Applied")
@@ -384,7 +396,7 @@ play_audio(speech, sample_rate)
# the noise contains the acoustic features of the environment. Therefore, we add
# the noise after RIR application.
noise, _ = get_noise_sample(resample=sample_rate)
noise = noise[:, :speech.shape[1]]
noise = noise[:, : speech.shape[1]]
snr_db = 8
scale = math.exp(snr_db / 10) * noise.norm(p=2) / speech.norm(p=2)
@@ -395,13 +407,20 @@ play_audio(speech, sample_rate)
# Apply filtering and change sample rate
speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
speech,
sample_rate,
effects=[
["lowpass", "4000"],
["compand", "0.02,0.05", "-60,-60,-30,-10,-20,-8,-5,-8,-2,-8", "-8", "-7", "0.05"],
["rate", "8000"],
],
speech,
sample_rate,
effects=[
["lowpass", "4000"],
[
"compand",
"0.02,0.05",
"-60,-60,-30,-10,-20,-8,-5,-8,-2,-8",
"-8",
"-7",
"0.05",
],
["rate", "8000"],
],
)
plot_specgram(speech, sample_rate, title="Filtered")
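######################################################################
# For readability, a sketch of how the ``compand`` arguments above are meant
# to be read, following the ``attack,decay  transfer  gain  initial-volume
# delay`` layout of the sox manual:
#
#   "0.02,0.05" - attack and decay times in seconds
#   "-60,-60,-30,-10,-20,-8,-5,-8,-2,-8" - transfer function (in-dB,out-dB pairs)
#   "-8"   - output gain in dB
#   "-7"   - initial volume in dB
#   "0.05" - delay in seconds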
......
@@ -23,14 +23,14 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
#@title Prepare data and utility functions. {display-mode: "form"}
#@markdown
#@markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go.
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
import multiprocessing
import os
@@ -42,52 +42,57 @@ _SAMPLE_DIR = "_assets"
YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def _download_yesno():
if os.path.exists(os.path.join(YESNO_DATASET_PATH, "waves_yesno.tar.gz")):
return
torchaudio.datasets.YESNO(root=YESNO_DATASET_PATH, download=True)
if os.path.exists(os.path.join(YESNO_DATASET_PATH, "waves_yesno.tar.gz")):
return
torchaudio.datasets.YESNO(root=YESNO_DATASET_PATH, download=True)
YESNO_DOWNLOAD_PROCESS = multiprocessing.Process(target=_download_yesno)
YESNO_DOWNLOAD_PROCESS.start()
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}')
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
######################################################################
# Here, we show how to use the ``YESNO`` dataset.
#
YESNO_DOWNLOAD_PROCESS.join()
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
for i in [1, 3, 5]:
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
play_audio(waveform, sample_rate)
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
play_audio(waveform, sample_rate)
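######################################################################
# Each item of the ``YESNO`` dataset is a ``(waveform, sample_rate, labels)``
# tuple; a quick sketch of inspecting one (the exact label values depend on
# the recording):

waveform, sample_rate, label = dataset[0]
print(waveform.shape)  # (1, num_frames) - mono recordings
print(sample_rate)  # 8000 for this dataset
print(label)  # e.g. [0, 1, 1, 0, 1, 1, 0, 1] - one 0/1 entry per spoken word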
@@ -20,16 +20,17 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
#@title Prepare data and utility functions. {display-mode: "form"}
#@markdown
#@markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go.
#@markdown
#@markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commons BY 4.0.
#-------------------------------------------------------------------------------
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
import os
import requests
@@ -40,62 +41,69 @@ import matplotlib.pyplot as plt
_SAMPLE_DIR = "_assets"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data():
uri = [
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
]
for url, path in uri:
with open(path, 'wb') as file_:
file_.write(requests.get(url).content)
uri = [
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
def _get_sample(path, resample=None):
effects = [
["remix", "1"]
]
if resample:
effects.extend([
["lowpass", f"{resample // 2}"],
["rate", f'{resample}'],
])
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def get_spectrogram(
n_fft = 400,
win_len = None,
hop_len = None,
power = 2.0,
n_fft=400,
win_len=None,
hop_len=None,
power=2.0,
):
waveform, _ = get_speech_sample()
spectrogram = T.Spectrogram(
n_fft=n_fft,
win_length=win_len,
hop_length=hop_len,
center=True,
pad_mode="reflect",
power=power,
)
return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or 'Spectrogram (db)')
axs.set_ylabel(ylabel)
axs.set_xlabel('frame')
im = axs.imshow(librosa.power_to_db(spec), origin='lower', aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
waveform, _ = get_speech_sample()
spectrogram = T.Spectrogram(
n_fft=n_fft,
win_length=win_len,
hop_length=hop_len,
center=True,
pad_mode="reflect",
power=power,
)
return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
######################################################################
# SpecAugment
@@ -111,18 +119,23 @@ def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=No
# ~~~~~~~~~~~
#
spec = get_spectrogram(power=None)
stretch = T.TimeStretch()
rate = 1.2
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect='equal', xmax=304)
plot_spectrogram(
torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect='equal', xmax=304)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect='equal', xmax=304)
plot_spectrogram(
torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
)
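######################################################################
# Note that ``TimeStretch`` operates on a complex-valued spectrogram, which
# is why the spectrogram above is created with ``power=None`` and plotted
# through ``torch.abs``. A sketch of fixing the rate at construction time
# instead of passing it per call:

stretch_fixed = T.TimeStretch(fixed_rate=1.2)
spec_fixed = stretch_fixed(spec)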
######################################################################
# TimeMasking
......
@@ -38,16 +38,17 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
#@title Prepare data and utility functions. {display-mode: "form"}
#@markdown
#@markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go.
#@markdown
#@markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commons BY 4.0.
#-------------------------------------------------------------------------------
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
import os
import requests
@@ -59,143 +60,154 @@ from IPython.display import Audio, display
_SAMPLE_DIR = "_assets"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data():
uri = [
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
]
for url, path in uri:
with open(path, 'wb') as file_:
file_.write(requests.get(url).content)
uri = [
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
def _get_sample(path, resample=None):
effects = [
["remix", "1"]
]
if resample:
effects.extend([
["lowpass", f"{resample // 2}"],
["rate", f'{resample}'],
])
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def print_stats(waveform, sample_rate=None, src=None):
if src:
print("-" * 10)
print("Source:", src)
print("-" * 10)
if sample_rate:
print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}")
print()
print(waveform)
print()
def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or 'Spectrogram (db)')
axs.set_ylabel(ylabel)
axs.set_xlabel('frame')
im = axs.imshow(librosa.power_to_db(spec), origin='lower', aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
if src:
print("-" * 10)
print("Source:", src)
print("-" * 10)
if sample_rate:
print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}")
print()
print(waveform)
print()
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}')
if xlim:
axes[c].set_xlim(xlim)
if ylim:
axes[c].set_ylim(ylim)
figure.suptitle(title)
plt.show(block=False)
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
if ylim:
axes[c].set_ylim(ylim)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def plot_mel_fbank(fbank, title=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or 'Filter bank')
axs.imshow(fbank, aspect='auto')
axs.set_ylabel('frequency bin')
axs.set_xlabel('mel bin')
plt.show(block=False)
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Filter bank")
axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin")
plt.show(block=False)
def plot_pitch(waveform, sample_rate, pitch):
figure, axis = plt.subplots(1, 1)
axis.set_title("Pitch Feature")
axis.grid(True)
figure, axis = plt.subplots(1, 1)
axis.set_title("Pitch Feature")
axis.grid(True)
end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color='gray', alpha=0.3)
end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, pitch.shape[1])
ln2 = axis2.plot(
time_axis, pitch[0], linewidth=2, label='Pitch', color='green')
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, pitch.shape[1])
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0)
plt.show(block=False)
axis2.legend(loc=0)
plt.show(block=False)
def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
figure, axis = plt.subplots(1, 1)
axis.set_title("Kaldi Pitch Feature")
axis.grid(True)
figure, axis = plt.subplots(1, 1)
axis.set_title("Kaldi Pitch Feature")
axis.grid(True)
end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color='gray', alpha=0.3)
time_axis = torch.linspace(0, end_time, pitch.shape[1])
ln1 = axis.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis.set_ylim((-1.3, 1.3))
time_axis = torch.linspace(0, end_time, pitch.shape[1])
ln1 = axis.plot(time_axis, pitch[0], linewidth=2, label='Pitch', color='green')
axis.set_ylim((-1.3, 1.3))
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, nfcc.shape[1])
ln2 = axis2.plot(
time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--"
)
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, nfcc.shape[1])
ln2 = axis2.plot(
time_axis, nfcc[0], linewidth=2, label='NFCC', color='blue', linestyle='--')
lns = ln1 + ln2
labels = [l.get_label() for l in lns]
axis.legend(lns, labels, loc=0)
plt.show(block=False)
lns = ln1 + ln2
labels = [l.get_label() for l in lns]
axis.legend(lns, labels, loc=0)
plt.show(block=False)
######################################################################
# Spectrogram
@@ -206,7 +218,6 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
#
waveform, sample_rate = get_speech_sample()
n_fft = 1024
@@ -226,7 +237,7 @@ spectrogram = T.Spectrogram(
spec = spectrogram(waveform)
print_stats(spec)
plot_spectrogram(spec[0], title='torchaudio')
plot_spectrogram(spec[0], title="torchaudio")
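######################################################################
# As a rule of thumb (a sketch; the exact values depend on the ``win_length``
# and ``hop_length`` set earlier in this file), the output shape follows:
#
#   frequency bins = n_fft // 2 + 1
#   time frames    ~ num_samples // hop_length + 1   (with ``center=True``)

print(spec.shape)  # e.g. (1, 513, num_frames) for n_fft=1024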
######################################################################
# GriffinLim
@@ -280,10 +291,10 @@ sample_rate = 6000
mel_filters = F.melscale_fbanks(
int(n_fft // 2 + 1),
n_mels=n_mels,
f_min=0.,
f_max=sample_rate/2.,
f_min=0.0,
f_max=sample_rate / 2.0,
sample_rate=sample_rate,
norm='slaney'
norm="slaney",
)
plot_mel_fbank(mel_filters, "Mel Filter Bank - torchaudio")
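######################################################################
# A sketch (with a hypothetical random input) of applying the filter bank to
# a power spectrogram of shape ``(freq, time)``, which mirrors what
# ``T.MelScale`` does internally:

dummy_spec = torch.rand(int(n_fft // 2 + 1), 100)
dummy_melspec = torch.matmul(dummy_spec.t(), mel_filters).t()
print(dummy_melspec.shape)  # (n_mels, 100)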
@@ -300,16 +311,16 @@ mel_filters_librosa = librosa.filters.mel(
sample_rate,
n_fft,
n_mels=n_mels,
fmin=0.,
fmax=sample_rate/2.,
norm='slaney',
fmin=0.0,
fmax=sample_rate / 2.0,
norm="slaney",
htk=True,
).T
plot_mel_fbank(mel_filters_librosa, "Mel Filter Bank - librosa")
mse = torch.square(mel_filters - mel_filters_librosa).mean().item()
print('Mean Square Difference: ', mse)
print("Mean Square Difference: ", mse)
######################################################################
# MelSpectrogram
@@ -336,15 +347,14 @@ mel_spectrogram = T.MelSpectrogram(
center=True,
pad_mode="reflect",
power=2.0,
norm='slaney',
norm="slaney",
onesided=True,
n_mels=n_mels,
mel_scale="htk",
)
melspec = mel_spectrogram(waveform)
plot_spectrogram(
melspec[0], title="MelSpectrogram - torchaudio", ylabel='mel freq')
plot_spectrogram(melspec[0], title="MelSpectrogram - torchaudio", ylabel="mel freq")
######################################################################
# Comparison against librosa
@@ -365,14 +375,13 @@ melspec_librosa = librosa.feature.melspectrogram(
pad_mode="reflect",
power=2.0,
n_mels=n_mels,
norm='slaney',
norm="slaney",
htk=True,
)
plot_spectrogram(
melspec_librosa, title="MelSpectrogram - librosa", ylabel='mel freq')
plot_spectrogram(melspec_librosa, title="MelSpectrogram - librosa", ylabel="mel freq")
mse = torch.square(melspec - melspec_librosa).mean().item()
print('Mean Square Difference: ', mse)
print("Mean Square Difference: ", mse)
######################################################################
# MFCC
@@ -391,11 +400,11 @@ mfcc_transform = T.MFCC(
sample_rate=sample_rate,
n_mfcc=n_mfcc,
melkwargs={
'n_fft': n_fft,
'n_mels': n_mels,
'hop_length': hop_length,
'mel_scale': 'htk',
}
"n_fft": n_fft,
"n_mels": n_mels,
"hop_length": hop_length,
"mel_scale": "htk",
},
)
mfcc = mfcc_transform(waveform)
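######################################################################
# ``melkwargs`` above is forwarded to the underlying ``MelSpectrogram``
# transform. The resulting shape is, as a sketch:

print(mfcc.shape)  # (1, n_mfcc, time) - the frame count depends on hop_length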
@@ -409,18 +418,27 @@ plot_spectrogram(mfcc[0])
melspec = librosa.feature.melspectrogram(
y=waveform.numpy()[0], sr=sample_rate, n_fft=n_fft,
win_length=win_length, hop_length=hop_length,
n_mels=n_mels, htk=True, norm=None)
y=waveform.numpy()[0],
sr=sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mels,
htk=True,
norm=None,
)
mfcc_librosa = librosa.feature.mfcc(
S=librosa.core.spectrum.power_to_db(melspec),
n_mfcc=n_mfcc, dct_type=2, norm='ortho')
S=librosa.core.spectrum.power_to_db(melspec),
n_mfcc=n_mfcc,
dct_type=2,
norm="ortho",
)
plot_spectrogram(mfcc_librosa)
mse = torch.square(mfcc - mfcc_librosa).mean().item()
print('Mean Square Difference: ', mse)
print("Mean Square Difference: ", mse)
######################################################################
# Pitch
......
@@ -21,12 +21,13 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
#@title Prepare data and utility functions. {display-mode: "form"}
#@markdown
#@markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go.
#@markdown
#@markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commons BY 4.0.
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from the [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
import io
@@ -51,7 +52,7 @@ SAMPLE_MP3_PATH = os.path.join(_SAMPLE_DIR, "steam.mp3")
SAMPLE_GSM_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/steam-train-whistle-daniel_simon.gsm"
SAMPLE_GSM_PATH = os.path.join(_SAMPLE_DIR, "steam.gsm")
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
SAMPLE_TAR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit.tar.gz"
@@ -61,117 +62,127 @@ SAMPLE_TAR_ITEM = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp
S3_BUCKET = "pytorch-tutorial-assets"
S3_KEY = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
def _fetch_data():
os.makedirs(_SAMPLE_DIR, exist_ok=True)
uri = [
(SAMPLE_WAV_URL, SAMPLE_WAV_PATH),
(SAMPLE_MP3_URL, SAMPLE_MP3_PATH),
(SAMPLE_GSM_URL, SAMPLE_GSM_PATH),
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
(SAMPLE_TAR_URL, SAMPLE_TAR_PATH),
]
for url, path in uri:
with open(path, 'wb') as file_:
file_.write(requests.get(url).content)
os.makedirs(_SAMPLE_DIR, exist_ok=True)
uri = [
(SAMPLE_WAV_URL, SAMPLE_WAV_PATH),
(SAMPLE_MP3_URL, SAMPLE_MP3_PATH),
(SAMPLE_GSM_URL, SAMPLE_GSM_PATH),
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
(SAMPLE_TAR_URL, SAMPLE_TAR_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
def print_stats(waveform, sample_rate=None, src=None):
if src:
print("-" * 10)
print("Source:", src)
print("-" * 10)
if sample_rate:
print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}")
print()
print(waveform)
print()
if src:
print("-" * 10)
print("Source:", src)
print("-" * 10)
if sample_rate:
print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}")
print()
print(waveform)
print()
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}')
if xlim:
axes[c].set_xlim(xlim)
if ylim:
axes[c].set_ylim(ylim)
figure.suptitle(title)
plt.show(block=False)
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
if ylim:
axes[c].set_ylim(ylim)
figure.suptitle(title)
plt.show(block=False)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}')
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def _get_sample(path, resample=None):
effects = [
["remix", "1"]
]
if resample:
effects.extend([
["lowpass", f"{resample // 2}"],
["rate", f'{resample}'],
])
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_PATH, resample=resample)
return _get_sample(SAMPLE_WAV_PATH, resample=resample)
def inspect_file(path):
print("-" * 10)
print("Source:", path)
print("-" * 10)
print(f" - File size: {os.path.getsize(path)} bytes")
print(f" - {torchaudio.info(path)}")
print("-" * 10)
print("Source:", path)
print("-" * 10)
print(f" - File size: {os.path.getsize(path)} bytes")
print(f" - {torchaudio.info(path)}")
######################################################################
# Quering audio metadata
# ----------------------
# Querying audio metadata
# -----------------------
#
# The ``torchaudio.info`` function fetches audio metadata. You can provide
# a path-like object or a file-like object.
#
metadata = torchaudio.info(SAMPLE_WAV_PATH)
print(metadata)
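######################################################################
# The returned ``AudioMetaData`` object exposes the individual fields as
# attributes, for example:

print(metadata.sample_rate)
print(metadata.num_frames)
print(metadata.num_channels)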
@@ -231,7 +242,7 @@ print(metadata)
print("Source:", SAMPLE_WAV_URL)
with requests.get(SAMPLE_WAV_URL, stream=True) as response:
metadata = torchaudio.info(response.raw)
metadata = torchaudio.info(response.raw)
print(metadata)
######################################################################
@@ -248,9 +259,9 @@ print(metadata)
print("Source:", SAMPLE_MP3_URL)
with requests.get(SAMPLE_MP3_URL, stream=True) as response:
metadata = torchaudio.info(response.raw, format="mp3")
metadata = torchaudio.info(response.raw, format="mp3")
print(f"Fetched {response.raw.tell()} bytes.")
print(f"Fetched {response.raw.tell()} bytes.")
print(metadata)
######################################################################
@@ -291,19 +302,19 @@ play_audio(waveform, sample_rate)
# Load audio data as HTTP request
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform, sample_rate = torchaudio.load(response.raw)
waveform, sample_rate = torchaudio.load(response.raw)
plot_specgram(waveform, sample_rate, title="HTTP datasource")
# Load audio from tar file
with tarfile.open(SAMPLE_TAR_PATH, mode='r') as tarfile_:
fileobj = tarfile_.extractfile(SAMPLE_TAR_ITEM)
waveform, sample_rate = torchaudio.load(fileobj)
with tarfile.open(SAMPLE_TAR_PATH, mode="r") as tarfile_:
fileobj = tarfile_.extractfile(SAMPLE_TAR_ITEM)
waveform, sample_rate = torchaudio.load(fileobj)
plot_specgram(waveform, sample_rate, title="TAR file")
# Load audio from S3
client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
response = client.get_object(Bucket=S3_BUCKET, Key=S3_KEY)
waveform, sample_rate = torchaudio.load(response['Body'])
waveform, sample_rate = torchaudio.load(response["Body"])
plot_specgram(waveform, sample_rate, title="From S3")
@@ -336,15 +347,16 @@ frame_offset, num_frames = 16000, 16000 # Fetch and decode the 1 - 2 seconds
print("Fetching all the data...")
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform1, sample_rate1 = torchaudio.load(response.raw)
waveform1 = waveform1[:, frame_offset:frame_offset+num_frames]
print(f" - Fetched {response.raw.tell()} bytes")
waveform1, sample_rate1 = torchaudio.load(response.raw)
waveform1 = waveform1[:, frame_offset: frame_offset + num_frames]
print(f" - Fetched {response.raw.tell()} bytes")
print("Fetching until the requested frames are available...")
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform2, sample_rate2 = torchaudio.load(
response.raw, frame_offset=frame_offset, num_frames=num_frames)
print(f" - Fetched {response.raw.tell()} bytes")
waveform2, sample_rate2 = torchaudio.load(
response.raw, frame_offset=frame_offset, num_frames=num_frames
)
print(f" - Fetched {response.raw.tell()} bytes")
print("Checking the resulting waveform ... ", end="")
assert (waveform1 == waveform2).all()
@@ -389,9 +401,7 @@ inspect_file(path)
# Save as 16-bit signed integer Linear PCM
# The resulting file occupies half the storage but loses precision
path = f"{_SAMPLE_DIR}/save_example_PCM_S16.wav"
torchaudio.save(
path, waveform, sample_rate,
encoding="PCM_S", bits_per_sample=16)
torchaudio.save(path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
inspect_file(path)
@@ -402,19 +412,19 @@ inspect_file(path)
waveform, sample_rate = get_sample(resample=8000)
formats = [
"mp3",
"flac",
"vorbis",
"sph",
"amb",
"amr-nb",
"gsm",
"mp3",
"flac",
"vorbis",
"sph",
"amb",
"amr-nb",
"gsm",
]
for format in formats:
path = f"{_SAMPLE_DIR}/save_example.{format}"
torchaudio.save(path, waveform, sample_rate, format=format)
inspect_file(path)
path = f"{_SAMPLE_DIR}/save_example.{format}"
torchaudio.save(path, waveform, sample_rate, format=format)
inspect_file(path)
######################################################################
@@ -435,4 +445,3 @@ torchaudio.save(buffer_, waveform, sample_rate, format="wav")
buffer_.seek(0)
print(buffer_.read(16))
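######################################################################
# The bytes printed above are the WAV container header; as a rough sketch of
# the RIFF layout, ``b'RIFF'`` and ``b'WAVE'`` should appear within the first
# 16 bytes.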
@@ -24,14 +24,14 @@ print(torchaudio.__version__)
# --------------------------------------------------------
#
#@title Prepare data and utility functions. {display-mode: "form"}
#@markdown
#@markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go.
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
import math
import time
@@ -46,95 +46,108 @@ DEFAULT_OFFSET = 201
SWEEP_MAX_SAMPLE_RATE = 48000
DEFAULT_LOWPASS_FILTER_WIDTH = 6
DEFAULT_ROLLOFF = 0.99
DEFAULT_RESAMPLING_METHOD = 'sinc_interpolation'
DEFAULT_RESAMPLING_METHOD = "sinc_interpolation"
def _get_log_freq(sample_rate, max_sweep_rate, offset):
"""Get freqs evenly spaced out in log-scale, between [0, max_sweep_rate // 2]
"""Get freqs evenly spaced out in log-scale, between [0, max_sweep_rate // 2]
offset is used to avoid negative infinity `log(offset + x)`.
offset is used to avoid negative infinity `log(offset + x)`.
"""
start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
return (
torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
)
"""
start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
def _get_inverse_log_freq(freq, sample_rate, offset):
"""Find the time where the given frequency is given by _get_log_freq"""
half = sample_rate // 2
return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset))
"""Find the time where the given frequency is given by _get_log_freq"""
half = sample_rate // 2
return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset))
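######################################################################
# In equation form (a sketch, ignoring the off-by-one of ``linspace``), with
# ``half = sample_rate // 2``:
#
#   freq(t) = offset * ((1 + half / offset) ** (t / sample_rate) - 1)
#
# so solving for ``t`` gives the inverse used above:
#
#   t = sample_rate * log(1 + freq / offset) / log(1 + half / offset)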
def _get_freq_ticks(sample_rate, offset, f_max):
# Given the original sample rate used for generating the sweep,
# find the x-axis value where the log-scale major frequency values fall in
time, freq = [], []
for exp in range(2, 5):
for v in range(1, 10):
f = v * 10 ** exp
if f < sample_rate // 2:
t = _get_inverse_log_freq(f, sample_rate, offset) / sample_rate
time.append(t)
freq.append(f)
t_max = _get_inverse_log_freq(f_max, sample_rate, offset) / sample_rate
time.append(t_max)
freq.append(f_max)
return time, freq
# Given the original sample rate used for generating the sweep,
# find the x-axis value where the log-scale major frequency values fall in
time, freq = [], []
for exp in range(2, 5):
for v in range(1, 10):
f = v * 10 ** exp
if f < sample_rate // 2:
t = _get_inverse_log_freq(f, sample_rate, offset) / sample_rate
time.append(t)
freq.append(f)
t_max = _get_inverse_log_freq(f_max, sample_rate, offset) / sample_rate
time.append(t_max)
freq.append(f_max)
return time, freq
def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET):
max_sweep_rate = sample_rate
freq = _get_log_freq(sample_rate, max_sweep_rate, offset)
delta = 2 * math.pi * freq / sample_rate
cumulative = torch.cumsum(delta, dim=0)
signal = torch.sin(cumulative).unsqueeze(dim=0)
return signal
def plot_sweep(waveform, sample_rate, title, max_sweep_rate=SWEEP_MAX_SAMPLE_RATE, offset=DEFAULT_OFFSET):
x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
y_ticks = [1000, 5000, 10000, 20000, sample_rate//2]
time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq]
freq_y = [f for f in freq if f >= 1000 and f in y_ticks and f <= sample_rate // 2]
figure, axis = plt.subplots(1, 1)
axis.specgram(waveform[0].numpy(), Fs=sample_rate)
plt.xticks(time, freq_x)
plt.yticks(freq_y, freq_y)
axis.set_xlabel('Original Signal Frequency (Hz, log scale)')
axis.set_ylabel('Waveform Frequency (Hz)')
axis.xaxis.grid(True, alpha=0.67)
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f'{title} (sample rate: {sample_rate} Hz)')
plt.show(block=True)
max_sweep_rate = sample_rate
freq = _get_log_freq(sample_rate, max_sweep_rate, offset)
delta = 2 * math.pi * freq / sample_rate
cumulative = torch.cumsum(delta, dim=0)
signal = torch.sin(cumulative).unsqueeze(dim=0)
return signal
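######################################################################
# The sweep is synthesized by phase accumulation, a discrete integral of the
# instantaneous frequency (a sketch):
#
#   phase[n] = sum over k <= n of 2 * pi * freq[k] / sample_rate
#   signal[n] = sin(phase[n])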
def plot_sweep(
waveform,
sample_rate,
title,
max_sweep_rate=SWEEP_MAX_SAMPLE_RATE,
offset=DEFAULT_OFFSET,
):
x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
y_ticks = [1000, 5000, 10000, 20000, sample_rate // 2]
time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq]
freq_y = [f for f in freq if f >= 1000 and f in y_ticks and f <= sample_rate // 2]
figure, axis = plt.subplots(1, 1)
axis.specgram(waveform[0].numpy(), Fs=sample_rate)
plt.xticks(time, freq_x)
plt.yticks(freq_y, freq_y)
axis.set_xlabel("Original Signal Frequency (Hz, log scale)")
axis.set_ylabel("Waveform Frequency (Hz)")
axis.xaxis.grid(True, alpha=0.67)
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.show(block=True)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}')
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
def benchmark_resample(
method,
@@ -146,30 +159,45 @@ def benchmark_resample(
resampling_method=DEFAULT_RESAMPLING_METHOD,
beta=None,
librosa_type=None,
iters=5
iters=5,
):
if method == "functional":
begin = time.time()
for _ in range(iters):
F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=lowpass_filter_width,
rolloff=rolloff, resampling_method=resampling_method)
elapsed = time.time() - begin
return elapsed / iters
elif method == "transforms":
resampler = T.Resample(sample_rate, resample_rate, lowpass_filter_width=lowpass_filter_width,
rolloff=rolloff, resampling_method=resampling_method, dtype=waveform.dtype)
begin = time.time()
for _ in range(iters):
resampler(waveform)
elapsed = time.time() - begin
return elapsed / iters
elif method == "librosa":
waveform_np = waveform.squeeze().numpy()
begin = time.time()
for _ in range(iters):
librosa.resample(waveform_np, sample_rate, resample_rate, res_type=librosa_type)
elapsed = time.time() - begin
return elapsed / iters
if method == "functional":
begin = time.time()
for _ in range(iters):
F.resample(
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=lowpass_filter_width,
rolloff=rolloff,
resampling_method=resampling_method,
)
elapsed = time.time() - begin
return elapsed / iters
elif method == "transforms":
resampler = T.Resample(
sample_rate,
resample_rate,
lowpass_filter_width=lowpass_filter_width,
rolloff=rolloff,
resampling_method=resampling_method,
dtype=waveform.dtype,
)
begin = time.time()
for _ in range(iters):
resampler(waveform)
elapsed = time.time() - begin
return elapsed / iters
elif method == "librosa":
waveform_np = waveform.squeeze().numpy()
begin = time.time()
for _ in range(iters):
librosa.resample(
waveform_np, sample_rate, resample_rate, res_type=librosa_type
)
elapsed = time.time() - begin
return elapsed / iters
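######################################################################
# A minimal usage sketch (hypothetical rates), timing a single configuration
# with the ``get_sine_sweep`` helper defined above:

waveform = get_sine_sweep(48000)
elapsed = benchmark_resample("functional", waveform, 48000, 32000)
print(f"{1000 * elapsed:.2f} ms per call")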
######################################################################
# To resample an audio waveform from one frequency to another, you can use
@@ -202,6 +230,7 @@ def benchmark_resample(
# plotted waveform, and color intensity the amplitude.
#
sample_rate = 48000
resample_rate = 32000
@@ -235,10 +264,14 @@ play_audio(waveform, sample_rate)
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, lowpass_filter_width=6
)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, lowpass_filter_width=128
)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")
@@ -282,10 +315,14 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation")
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation"
)
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window")
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, resampling_method="kaiser_window"
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
@@ -301,7 +338,7 @@ plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
sample_rate = 48000
resample_rate = 32000
### kaiser_best
# kaiser_best
resampled_waveform = F.resample(
waveform,
sample_rate,
@@ -309,18 +346,23 @@ resampled_waveform = F.resample(
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492
beta=14.769656459379492,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), sample_rate, resample_rate, res_type='kaiser_best')).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")
librosa.resample(
waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_best"
)
).unsqueeze(0)
plot_sweep(
librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)"
)
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse)
### kaiser_fast
# kaiser_fast
resampled_waveform = F.resample(
waveform,
sample_rate,
......@@ -328,13 +370,20 @@ resampled_waveform = F.resample(
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386
beta=8.555504641634386,
)
plot_specgram(
resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)"
)
plot_specgram(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")
librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), sample_rate, resample_rate, res_type='kaiser_fast')).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")
librosa.resample(
waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_fast"
)
).unsqueeze(0)
plot_sweep(
librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)"
)
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse)
......@@ -371,71 +420,87 @@ configs = {
}
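# The body of ``configs`` is elided by the hunk above; from the loop below,
# each entry maps a display label to a ``[sample_rate, resample_rate]`` pair,
# e.g. ``"downsample (48 -> 44.1 kHz)": [48000, 44100]`` (hypothetical entry).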
for label in configs:
times, rows = [], []
sample_rate = configs[label][0]
resample_rate = configs[label][1]
waveform = get_sine_sweep(sample_rate)
# sinc 64 zero-crossings
f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append(f"sinc (width 64)")
# sinc 16 zero-crossings
f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append(f"sinc (width 16)")
# kaiser best
lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best")
f_time = benchmark_resample(
"functional",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492)
t_time = benchmark_resample(
"transforms",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492)
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
rows.append(f"kaiser_best")
# kaiser fast
lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast")
f_time = benchmark_resample(
"functional",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386)
t_time = benchmark_resample(
"transforms",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386)
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
rows.append(f"kaiser_fast")
df = pd.DataFrame(times,
columns=["librosa", "functional", "transforms"],
index=rows)
df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"],df.columns])
display(df.round(2))
times, rows = [], []
sample_rate = configs[label][0]
resample_rate = configs[label][1]
waveform = get_sine_sweep(sample_rate)
# sinc 64 zero-crossings
f_time = benchmark_resample(
"functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64
)
t_time = benchmark_resample(
"transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64
)
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append("sinc (width 64)")
# sinc 16 zero-crossings
f_time = benchmark_resample(
"functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16
)
t_time = benchmark_resample(
"transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16
)
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append("sinc (width 16)")
# kaiser best
lib_time = benchmark_resample(
"librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best"
)
f_time = benchmark_resample(
"functional",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492,
)
t_time = benchmark_resample(
"transforms",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492,
)
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
rows.append("kaiser_best")
# kaiser fast
lib_time = benchmark_resample(
"librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast"
)
f_time = benchmark_resample(
"functional",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
t_time = benchmark_resample(
"transforms",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
rows.append("kaiser_fast")
df = pd.DataFrame(
times, columns=["librosa", "functional", "transforms"], index=rows
)
df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns])
display(df.round(2))
......@@ -9,11 +9,11 @@ MVDR with torchaudio
######################################################################
# Overview
# --------
#
#
# This is a tutorial on how to apply MVDR beamforming by using `torchaudio <https://github.com/pytorch/audio>`__.
#
#
# Steps
#
#
# - Ideal Ratio Mask (IRM) is generated by dividing the clean/noise
# magnitude by the mixture magnitude.
# - We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``)
......@@ -26,22 +26,22 @@ MVDR with torchaudio
######################################################################
# Preparation
# -----------
#
#
# First, we import the necessary packages and retrieve the data.
#
#
# The multi-channel audio example is selected from the
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
# dataset.
#
#
# The original filename is
#
#
# ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
#
#
# which was generated with:
#
#
# - ``SSB07200001.wav`` from `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License v.2.0)
# - ``noise-sound-bible-0038.wav`` from `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0 International — CC BY 4.0)
#
# - ``noise-sound-bible-0038.wav`` from `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0 International — CC BY 4.0) # noqa: E501
#
import os
import requests
......@@ -50,48 +50,48 @@ import torchaudio
import IPython.display as ipd
torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
print(torchaudio.__version__)
print(device)
filenames = [
'mix.wav',
'reverb_clean.wav',
'clean.wav',
"mix.wav",
"reverb_clean.wav",
"clean.wav",
]
base_url = 'https://download.pytorch.org/torchaudio/tutorial-assets/mvdr'
base_url = "https://download.pytorch.org/torchaudio/tutorial-assets/mvdr"
for filename in filenames:
os.makedirs('_assets', exist_ok=True)
os.makedirs("_assets", exist_ok=True)
if not os.path.exists(filename):
with open(f'_assets/{filename}', 'wb') as file:
file.write(requests.get(f'{base_url}/{filename}').content)
with open(f"_assets/{filename}", "wb") as file:
file.write(requests.get(f"{base_url}/{filename}").content)
######################################################################
# Generate the Ideal Ratio Mask (IRM)
# -----------------------------------
#
#
######################################################################
# Loading audio data
# ~~~~~~~~~~~~~~~~~~
#
#
mix, sr = torchaudio.load('_assets/mix.wav')
reverb_clean, sr2 = torchaudio.load('_assets/reverb_clean.wav')
clean, sr3 = torchaudio.load('_assets/clean.wav')
mix, sr = torchaudio.load("_assets/mix.wav")
reverb_clean, sr2 = torchaudio.load("_assets/reverb_clean.wav")
clean, sr3 = torchaudio.load("_assets/clean.wav")
assert sr == sr2
noise = mix - reverb_clean
######################################################################
#
#
# .. note::
# The MVDR Module requires ``torch.cdouble`` dtype for noisy STFT.
# We need to convert the dtype of the waveforms to ``torch.double``.
#
#
mix = mix.to(torch.double)
noise = noise.to(torch.double)
......@@ -101,7 +101,7 @@ reverb_clean = reverb_clean.to(torch.double)
######################################################################
# Compute STFT
# ~~~~~~~~~~~~
#
#
stft = torchaudio.transforms.Spectrogram(
n_fft=1024,
......@@ -118,15 +118,14 @@ spec_noise = stft(noise)
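# The hunk above elides the remaining ``Spectrogram`` arguments and the
# inverse transform. A sketch of the ``istft`` used below, with assumed
# parameters that must mirror the forward STFT (the hop length here is a
# guess, not taken from the source):
istft = torchaudio.transforms.InverseSpectrogram(n_fft=1024, hop_length=256)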
######################################################################
# Generate the Ideal Ratio Mask (IRM)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
# .. note::
# We found using the mask directly performs better than using the
# square root of it. This is slightly different from the definition of IRM.
#
#
def get_irms(spec_clean, spec_noise, spec_mix):
mag_mix = spec_mix.abs() ** 2
def get_irms(spec_clean, spec_noise):
mag_clean = spec_clean.abs() ** 2
mag_noise = spec_noise.abs() ** 2
irm_speech = mag_clean / (mag_clean + mag_noise)
......@@ -134,25 +133,26 @@ def get_irms(spec_clean, spec_noise, spec_mix):
return irm_speech, irm_noise
######################################################################
# .. note::
# We use reverberant clean speech as the target here;
# you can also set it to dry clean speech.
irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise, spec_mix)
irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise)
######################################################################
# Apply MVDR
# ----------
#
#
######################################################################
# Apply MVDR beamforming by using multi-channel masks
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
results_multi = {}
for solution in ['ref_channel', 'stv_evd', 'stv_power']:
for solution in ["ref_channel", "stv_evd", "stv_power"]:
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True)
stft_est = mvdr(spec_mix, irm_speech, irm_noise)
est = istft(stft_est, length=mix.shape[-1])
......@@ -161,13 +161,15 @@ for solution in ['ref_channel', 'stv_evd', 'stv_power']:
######################################################################
# Apply MVDR beamforming by using single-channel masks
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
# We use the 1st channel as an example.
# The channel selection may depend on the design of the microphone array.
results_single = {}
for solution in ['ref_channel', 'stv_evd', 'stv_power']:
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)
for solution in ["ref_channel", "stv_evd", "stv_power"]:
mvdr = torchaudio.transforms.MVDR(
ref_channel=0, solution=solution, multi_mask=False
)
stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])
est = istft(stft_est, length=mix.shape[-1])
results_single[solution] = est
......@@ -175,7 +177,8 @@ for solution in ['ref_channel', 'stv_evd', 'stv_power']:
######################################################################
# Compute Si-SDR scores
# ~~~~~~~~~~~~~~~~~~~~~
#
#
def si_sdr(estimate, reference, epsilon=1e-8):
estimate = estimate - estimate.mean()
......@@ -196,96 +199,101 @@ def si_sdr(estimate, reference, epsilon=1e-8):
sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
return sisdr.item()
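# The definition above is split across an elided hunk; for reference, a
# self-contained sketch of Si-SDR under its usual definition (a zero-mean,
# scale-aligned signal-to-distortion ratio in dB) might look like this:
def si_sdr_sketch(estimate, reference, epsilon=1e-8):
    # Remove the DC offset from both signals
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    # Project the estimate onto the reference to find the optimal scale
    reference_pow = reference.pow(2).mean(dim=1, keepdim=True)
    mix_pow = (estimate * reference).mean(dim=1, keepdim=True)
    reference = mix_pow / (reference_pow + epsilon) * reference
    # Ratio of scaled-reference power to residual-error power, in dB
    error = estimate - reference
    reference_pow = reference.pow(2).mean(dim=1)
    error_pow = error.pow(2).mean(dim=1)
    return (10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)).item()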
######################################################################
# Results
# -------
#
#
######################################################################
# Single-channel mask results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
for solution in results_single:
print(solution+": ", si_sdr(results_single[solution][None,...], reverb_clean[0:1]))
print(
solution + ": ", si_sdr(results_single[solution][None, ...], reverb_clean[0:1])
)
######################################################################
# Multi-channel mask results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
for solution in results_multi:
print(solution+": ", si_sdr(results_multi[solution][None,...], reverb_clean[0:1]))
print(
solution + ": ", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1])
)
######################################################################
# Original audio
# --------------
#
#
######################################################################
# Mixture speech
# ~~~~~~~~~~~~~~
#
#
ipd.Audio(mix[0], rate=16000)
######################################################################
# Noise
# ~~~~~
#
#
ipd.Audio(noise[0], rate=16000)
######################################################################
# Clean speech
# ~~~~~~~~~~~~
#
#
ipd.Audio(clean[0], rate=16000)
######################################################################
# Enhanced audio
# --------------
#
#
######################################################################
# Multi-channel mask, ref_channel solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
ipd.Audio(results_multi['ref_channel'], rate=16000)
ipd.Audio(results_multi["ref_channel"], rate=16000)
######################################################################
# Multi-channel mask, stv_evd solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
ipd.Audio(results_multi['stv_evd'], rate=16000)
ipd.Audio(results_multi["stv_evd"], rate=16000)
######################################################################
# Multi-channel mask, stv_power solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
ipd.Audio(results_multi['stv_power'], rate=16000)
ipd.Audio(results_multi["stv_power"], rate=16000)
######################################################################
# Single-channel mask, ref_channel solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
ipd.Audio(results_single['ref_channel'], rate=16000)
ipd.Audio(results_single["ref_channel"], rate=16000)
######################################################################
# Single-channel mask, stv_evd solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
ipd.Audio(results_single['stv_evd'], rate=16000)
ipd.Audio(results_single["stv_evd"], rate=16000)
######################################################################
# Single-channel mask, stv_power solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
ipd.Audio(results_single['stv_power'], rate=16000)
ipd.Audio(results_single["stv_power"], rate=16000)
......@@ -14,28 +14,28 @@ pre-trained models from wav2vec 2.0
######################################################################
# Overview
# --------
#
#
# The process of speech recognition looks like the following.
#
#
# 1. Extract the acoustic features from audio waveform
#
#
# 2. Estimate the class of the acoustic features frame-by-frame
#
#
# 3. Generate hypotheses from the sequence of class probabilities
#
#
# Torchaudio provides easy access to the pre-trained weights and
# associated information, such as the expected sample rate and class
# labels. They are bundled together and available under
# the ``torchaudio.pipelines`` module.
#
#
######################################################################
# Preparation
# -----------
#
#
# First, we import the necessary packages and fetch the data that we work on.
#
#
# %matplotlib inline
......@@ -48,52 +48,52 @@ import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SPEECH_FILE = "_assets/speech.wav"
if not os.path.exists(SPEECH_FILE):
os.makedirs('_assets', exist_ok=True)
with open(SPEECH_FILE, 'wb') as file:
file.write(requests.get(SPEECH_URL).content)
os.makedirs("_assets", exist_ok=True)
with open(SPEECH_FILE, "wb") as file:
file.write(requests.get(SPEECH_URL).content)
######################################################################
# Creating a pipeline
# -------------------
#
#
# First, we will create a Wav2Vec2 model that performs the feature
# extraction and the classification.
#
#
# There are two types of Wav2Vec2 pre-trained weights available in
# torchaudio: the ones fine-tuned for the ASR task, and the ones not
# fine-tuned.
#
#
# Wav2Vec2 (and HuBERT) models are trained in a self-supervised manner. They
# are first trained with audio only for representation learning, then
# fine-tuned for a specific task with additional labels.
#
#
# The pre-trained weights without fine-tuning can be fine-tuned
# for other downstream tasks as well, but this tutorial does not
# cover that.
#
#
# We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.
#
#
# There are multiple models available in
# :py:mod:`torchaudio.pipelines`. Please check the documentation for
# the details of how they are trained.
#
#
# The bundle object provides the interface to instantiate the model and other
# information. The sampling rate and the class labels are found as follows.
#
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
......@@ -105,7 +105,7 @@ print("Labels:", bundle.get_labels())
######################################################################
# The model can be constructed as follows. This process will automatically
# fetch the pre-trained weights and load them into the model.
#
#
model = bundle.get_model().to(device)
......@@ -115,62 +115,62 @@ print(model.__class__)
######################################################################
# Loading data
# ------------
#
#
# We will use the speech data from the `VOiCES
# dataset <https://iqtlabs.github.io/voices/>`__, which is licensed under
# Creative Commons BY 4.0.
#
#
IPython.display.Audio(SPEECH_FILE)
######################################################################
# To load data, we use :py:func:`torchaudio.load`.
#
#
# If the sampling rate is different from what the pipeline expects, then
# we can use :py:func:`torchaudio.functional.resample` for resampling.
#
#
# .. note::
#
# - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.
# - When performing resampling multiple times on the same set of sample rates,
# using :py:func:`torchaudio.transforms.Resample` might improve the performance
# (see the sketch below).
#
#
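# As a sketch of the second point above, the transform can be constructed
# once and reused across waveforms that share the same pair of rates
# (the rates and the input batch below are hypothetical):
#
# ::
#
#     resampler = torchaudio.transforms.Resample(orig_freq=8000, new_freq=16000)
#     for w in waveforms_at_8k:  # hypothetical batch of 8 kHz waveforms
#         w = resampler(w)
#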
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)
if sample_rate != bundle.sample_rate:
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
######################################################################
# Extracting acoustic features
# ----------------------------
#
#
# The next step is to extract acoustic features from the audio.
#
#
# .. note::
# Wav2Vec2 models fine-tuned for the ASR task can perform feature
# extraction and classification in one step, but for the sake of the
# tutorial, we also show how to perform feature extraction here.
#
#
with torch.inference_mode():
features, _ = model.extract_features(waveform)
features, _ = model.extract_features(waveform)
######################################################################
# The returned ``features`` is a list of tensors. Each tensor is the output of
# a transformer layer.
#
#
fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
ax[i].imshow(feats[0].cpu())
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
ax[i].imshow(feats[0].cpu())
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()
......@@ -178,24 +178,24 @@ plt.show()
######################################################################
# Feature classification
# ----------------------
#
#
# Once the acoustic features are extracted, the next step is to classify
# them into a set of categories.
#
#
# The Wav2Vec2 model provides a method to perform the feature extraction and
# classification in one step.
#
#
with torch.inference_mode():
emission, _ = model(waveform)
emission, _ = model(waveform)
######################################################################
# The output is in the form of logits, not probabilities.
#
#
# Let’s visualize this.
#
#
plt.imshow(emission[0].cpu().T)
plt.title("Classification result")
......@@ -208,62 +208,63 @@ print("Class labels:", bundle.get_labels())
######################################################################
# We can see that there are strong indications for certain labels across
# the timeline.
#
#
######################################################################
# Generating transcripts
# ----------------------
#
#
# From the sequence of label probabilities, now we want to generate
# transcripts. The process to generate hypotheses is often called
# “decoding”.
#
#
# Decoding is more elaborate than simple classification because
# decoding at a certain time step can be affected by surrounding
# observations.
#
#
# For example, take words like ``night`` and ``knight``. Even if their
# prior probability distributions are different (in typical conversations,
# ``night`` would occur way more often than ``knight``), to accurately
# generate transcripts with ``knight``, such as ``a knight with a sword``,
# the decoding process has to postpone the final decision until it sees
# enough context.
#
#
# There are many decoding techniques proposed, and they require external
# resources, such as word dictionary and language models.
#
#
# In this tutorial, for the sake of simplicity, we will perform greedy
# decoding, which does not depend on such external components, and simply
# pick the best hypothesis at each time step. Therefore, the context
# information is not used, and only one transcript can be generated.
#
#
# We start by defining the greedy decoding algorithm.
#
#
class GreedyCTCDecoder(torch.nn.Module):
def __init__(self, labels, blank=0):
super().__init__()
self.labels = labels
self.blank = blank
def __init__(self, labels, blank=0):
super().__init__()
self.labels = labels
self.blank = blank
def forward(self, emission: torch.Tensor) -> str:
"""Given a sequence emission over labels, get the best path string
Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
def forward(self, emission: torch.Tensor) -> str:
"""Given a sequence emission over labels, get the best path string
Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
Returns:
str: The resulting transcript
"""
indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank]
return ''.join([self.labels[i] for i in indices])
Returns:
str: The resulting transcript
"""
indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank]
return "".join([self.labels[i] for i in indices])
######################################################################
# Now create the decoder object and decode the transcript.
#
#
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
......@@ -271,7 +272,7 @@ transcript = decoder(emission[0])
######################################################################
# Let’s check the result and listen again to the audio.
#
#
print(transcript)
IPython.display.Audio(SPEECH_FILE)
......@@ -283,19 +284,19 @@ IPython.display.Audio(SPEECH_FILE)
# `here <https://distill.pub/2017/ctc/>`__. In CTC, a blank token (ϵ) is a
# special token which separates repetitions of the same symbol. In decoding,
# consecutive duplicates are merged and the blank tokens are simply dropped,
# as the toy example below illustrates.
#
#
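# As a toy illustration with hypothetical label indices (blank assumed at
# index 0), the merge-then-drop behavior looks like this:
toy_labels = ("-", "H", "E", "L", "O")  # "-" plays the role of the blank token
toy_indices = torch.tensor([1, 1, 2, 0, 3, 3, 0, 3, 4, 4])
toy_indices = torch.unique_consecutive(toy_indices)  # -> 1, 2, 0, 3, 0, 3, 4
print("".join(toy_labels[i] for i in toy_indices if i != 0))  # -> HELLO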
######################################################################
# Conclusion
# ----------
#
#
# In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to
# perform acoustic feature extraction and speech recognition. Constructing
# a model and getting the emission is as short as two lines.
#
#
# ::
#
#
# model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
# emission = model(waveforms, ...)
#
#
......@@ -10,24 +10,24 @@ Text-to-Speech with Tacotron2
######################################################################
# Overview
# --------
#
#
# This tutorial shows how to build a text-to-speech pipeline, using the
# pretrained Tacotron2 in torchaudio.
#
#
# The text-to-speech pipeline goes as follows:
#
#
# 1. Text preprocessing
#
#
# First, the input text is encoded into a list of symbols. In this
# tutorial, we will use English characters and phonemes as the symbols.
#
#
# 2. Spectrogram generation
#
#
# From the encoded text, a spectrogram is generated. We use the ``Tacotron2``
# model for this.
#
#
# 3. Time-domain conversion
#
#
# The last step is converting the spectrogram into the waveform. The
# component that generates speech from a spectrogram is also called a vocoder.
# In this tutorial, three different vocoders are used,
......@@ -35,23 +35,23 @@ Text-to-Speech with Tacotron2
# `Griffin-Lim <https://pytorch.org/audio/stable/transforms.html#griffinlim>`__,
# and
# `Nvidia's WaveGlow <https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/>`__.
#
#
#
#
# The following figure illustrates the whole process.
#
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/tacotron2_tts_pipeline.png
#
#
# All the related components are bundled in :py:func:`torchaudio.pipelines.Tacotron2TTSBundle`,
# but this tutorial will also cover the process under the hood.
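# As a preview, the bundle-based pipeline (a sketch; each step is unpacked
# in the sections below) condenses to roughly:
#
# ::
#
#     bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
#     processor = bundle.get_text_processor()
#     tacotron2 = bundle.get_tacotron2()
#     vocoder = bundle.get_vocoder()
#     with torch.inference_mode():
#         processed, lengths = processor("Hello world!")
#         spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
#         waveforms, lengths = vocoder(spec, spec_lengths)
#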
######################################################################
# Preparation
# -----------
#
#
# First, we install the necessary dependencies. In addition to
# ``torchaudio``, ``DeepPhonemizer`` is required to perform phoneme-based
# encoding.
#
#
# When running this example in notebook, install DeepPhonemizer
# !pip3 install deep_phonemizer
......@@ -63,7 +63,7 @@ import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
......@@ -76,36 +76,38 @@ print(device)
######################################################################
# Text Processing
# ---------------
#
#
######################################################################
# Character-based encoding
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
#
# In this section, we will go through how the character-based encoding
# works.
#
#
# Since the pre-trained Tacotron2 model expects a specific set of symbol
# tables, the same functionalities are available in ``torchaudio``. This
# section is more of an explanation of the basis of the encoding.
#
#
# First, we define the set of symbols. For example, we can use
# ``'_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'``. Then, we will map
# each character of the input text to the index of the corresponding
# symbol in the table.
#
#
# The following is an example of such processing. In the example, symbols
# that are not in the table are ignored.
#
#
symbols = '_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)}
symbols = set(symbols)
def text_to_sequence(text):
text = text.lower()
return [look_up[s] for s in text if s in symbols]
text = text.lower()
return [look_up[s] for s in text if s in symbols]
text = "Hello world! Text to speech!"
print(text_to_sequence(text))
......@@ -116,7 +118,7 @@ print(text_to_sequence(text))
# what the pretrained Tacotron2 model expects. ``torchaudio`` provides the
# transform along with the pretrained model. For example, you can
# instantiate and use such a transform as follows.
#
#
processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()
......@@ -132,33 +134,33 @@ print(lengths)
# When a list of texts is provided, the returned ``lengths`` variable
# represents the valid length of each processed token sequence in the output
# batch.
#
#
# The intermediate representation can be retrieved as follows.
#
#
print([processor.tokens[i] for i in processed[0, :lengths[0]]])
print([processor.tokens[i] for i in processed[0, : lengths[0]]])
######################################################################
# Phoneme-based encoding
# ~~~~~~~~~~~~~~~~~~~~~~
#
#
# Phoneme-based encoding is similar to character-based encoding, but it
# uses a symbol table based on phonemes and a G2P (Grapheme-to-Phoneme)
# model.
#
#
# The details of the G2P model are out of the scope of this tutorial; we will
# just look at what the conversion looks like.
#
#
# Similar to the case of character-based encoding, the encoding process is
# expected to match what a pretrained Tacotron2 model is trained on.
# ``torchaudio`` has an interface to create the process.
#
#
# The following code illustrates how to make and use the process. Behind
# the scenes, a G2P model is created using the ``DeepPhonemizer`` package, and
# the pretrained weights published by the author of ``DeepPhonemizer`` are
# fetched.
#
#
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
......@@ -166,7 +168,7 @@ processor = bundle.get_text_processor()
text = "Hello world! Text to speech!"
with torch.inference_mode():
processed, lengths = processor(text)
processed, lengths = processor(text)
print(processed)
print(lengths)
......@@ -175,30 +177,30 @@ print(lengths)
######################################################################
# Notice that the encoded values are different from the example of
# character-based encoding.
#
#
# The intermediate representation looks like the following.
#
#
print([processor.tokens[i] for i in processed[0, :lengths[0]]])
print([processor.tokens[i] for i in processed[0, : lengths[0]]])
######################################################################
# Spectrogram Generation
# ----------------------
#
#
# ``Tacotron2`` is the model we use to generate a spectrogram from the
# encoded text. For the details of the model, please refer to `the
# paper <https://arxiv.org/abs/1712.05884>`__.
#
#
# It is easy to instantiate a Tacotron2 model with pretrained weights;
# however, note that the input to Tacotron2 models needs to be processed
# by the matching text processor.
#
#
# :py:func:`torchaudio.pipelines.Tacotron2TTSBundle` bundles the matching
# models and processors together so that it is easy to create the pipeline.
#
#
# For the available bundles and their usage, please refer to :py:mod:`torchaudio.pipelines`.
#
#
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()
......@@ -207,10 +209,10 @@ tacotron2 = bundle.get_tacotron2().to(device)
text = "Hello world! Text to speech!"
with torch.inference_mode():
processed, lengths = processor(text)
processed = processed.to(device)
lengths = lengths.to(device)
spec, _, _ = tacotron2.infer(processed, lengths)
processed, lengths = processor(text)
processed = processed.to(device)
lengths = lengths.to(device)
spec, _, _ = tacotron2.infer(processed, lengths)
plt.imshow(spec[0].cpu().detach())
......@@ -219,36 +221,36 @@ plt.imshow(spec[0].cpu().detach())
######################################################################
# Note that the ``Tacotron2.infer`` method performs multinomial sampling;
# therefore, the process of generating the spectrogram incurs randomness.
#
#
fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3))
for i in range(3):
with torch.inference_mode():
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
print(spec[0].shape)
ax[i].imshow(spec[0].cpu().detach())
with torch.inference_mode():
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
print(spec[0].shape)
ax[i].imshow(spec[0].cpu().detach())
plt.show()
######################################################################
# Waveform Generation
# -------------------
#
#
# Once the spectrogram is generated, the last process is to recover the
# waveform from the spectrogram.
#
#
# ``torchaudio`` provides vocoders based on ``GriffinLim`` and
# ``WaveRNN``.
#
#
######################################################################
# WaveRNN
# ~~~~~~~
#
#
# Continuing from the previous section, we can instantiate the matching
# WaveRNN model from the same bundle.
#
#
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
......@@ -259,27 +261,29 @@ vocoder = bundle.get_vocoder().to(device)
text = "Hello world! Text to speech!"
with torch.inference_mode():
processed, lengths = processor(text)
processed = processed.to(device)
lengths = lengths.to(device)
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths)
processed, lengths = processor(text)
processed = processed.to(device)
lengths = lengths.to(device)
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())
torchaudio.save("_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
torchaudio.save(
"_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate
)
IPython.display.Audio("_assets/output_wavernn.wav")
######################################################################
# Griffin-Lim
# ~~~~~~~~~~~
#
#
# Using the Griffin-Lim vocoder is the same as WaveRNN. You can instantiate
# the vocoder object with the ``get_vocoder`` method and pass the spectrogram.
#
#
bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH
......@@ -288,34 +292,49 @@ tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device)
with torch.inference_mode():
processed, lengths = processor(text)
processed = processed.to(device)
lengths = lengths.to(device)
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
processed, lengths = processor(text)
processed = processed.to(device)
lengths = lengths.to(device)
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())
torchaudio.save("_assets/output_griffinlim.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
torchaudio.save(
"_assets/output_griffinlim.wav",
waveforms[0:1].cpu(),
sample_rate=vocoder.sample_rate,
)
IPython.display.Audio("_assets/output_griffinlim.wav")
######################################################################
# Waveglow
# ~~~~~~~~
#
#
# Waveglow is a vocoder published by Nvidia. The pretrained weights are
# published on Torch Hub. One can instantiate the model using the ``torch.hub``
# module.
#
#
# Workaround to load model mapped on GPU
# https://stackoverflow.com/a/61840832
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp32', pretrained=False)
checkpoint = torch.hub.load_state_dict_from_url('https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth', progress=False, map_location=device)
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
waveglow = torch.hub.load(
"NVIDIA/DeepLearningExamples:torchhub",
"nvidia_waveglow",
model_math="fp32",
pretrained=False,
)
checkpoint = torch.hub.load_state_dict_from_url(
"https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth", # noqa: E501
progress=False,
map_location=device,
)
state_dict = {
key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()
}
waveglow.load_state_dict(state_dict)
waveglow = waveglow.remove_weightnorm(waveglow)
......@@ -323,7 +342,7 @@ waveglow = waveglow.to(device)
waveglow.eval()
with torch.no_grad():
waveforms = waveglow.infer(spec)
waveforms = waveglow.infer(spec)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())