Unverified commit 0ac196c6, authored by nateanl, committed by GitHub

Fix style checks in examples/tutorials (#2006)

parent e6dd7fd3
...@@ -22,16 +22,17 @@ print(torchaudio.__version__) ...@@ -22,16 +22,17 @@ print(torchaudio.__version__)
# -------------------------------------------------------- # --------------------------------------------------------
# #
#@title Prepare data and utility functions. {display-mode: "form"} # @title Prepare data and utility functions. {display-mode: "form"}
#@markdown # @markdown
#@markdown You do not need to look into this cell. # @markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go. # @markdown Just execute once and you are good to go.
#@markdown # @markdown
#@markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commos BY 4.0. # @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# Preparation of data and helper functions. # Preparation of data and helper functions.
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
import math import math
import os import os
...@@ -46,125 +47,135 @@ _SAMPLE_DIR = "_assets" ...@@ -46,125 +47,135 @@ _SAMPLE_DIR = "_assets"
SAMPLE_WAV_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/steam-train-whistle-daniel_simon.wav" SAMPLE_WAV_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/steam-train-whistle-daniel_simon.wav"
SAMPLE_WAV_PATH = os.path.join(_SAMPLE_DIR, "steam.wav") SAMPLE_WAV_PATH = os.path.join(_SAMPLE_DIR, "steam.wav")
SAMPLE_RIR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav" SAMPLE_RIR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav" # noqa: E501
SAMPLE_RIR_PATH = os.path.join(_SAMPLE_DIR, "rir.wav") SAMPLE_RIR_PATH = os.path.join(_SAMPLE_DIR, "rir.wav")
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav") SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav" SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav" # noqa: E501
SAMPLE_NOISE_PATH = os.path.join(_SAMPLE_DIR, "bg.wav") SAMPLE_NOISE_PATH = os.path.join(_SAMPLE_DIR, "bg.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True) os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data(): def _fetch_data():
uri = [ uri = [
(SAMPLE_WAV_URL, SAMPLE_WAV_PATH), (SAMPLE_WAV_URL, SAMPLE_WAV_PATH),
(SAMPLE_RIR_URL, SAMPLE_RIR_PATH), (SAMPLE_RIR_URL, SAMPLE_RIR_PATH),
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH), (SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
(SAMPLE_NOISE_URL, SAMPLE_NOISE_PATH), (SAMPLE_NOISE_URL, SAMPLE_NOISE_PATH),
] ]
for url, path in uri: for url, path in uri:
with open(path, 'wb') as file_: with open(path, "wb") as file_:
file_.write(requests.get(url).content) file_.write(requests.get(url).content)
_fetch_data() _fetch_data()
def _get_sample(path, resample=None):
    effects = [["remix", "1"]]
    if resample:
        effects.extend(
            [
                ["lowpass", f"{resample // 2}"],
                ["rate", f"{resample}"],
            ]
        )
    return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_sample(*, resample=None): def get_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_PATH, resample=resample) return _get_sample(SAMPLE_WAV_PATH, resample=resample)
def get_speech_sample(*, resample=None): def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample) return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None): def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy() waveform = waveform.numpy()
num_channels, num_frames = waveform.shape num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1) figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1: if num_channels == 1:
axes = [axes] axes = [axes]
for c in range(num_channels): for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1) axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True) axes[c].grid(True)
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}') axes[c].set_ylabel(f"Channel {c+1}")
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
if ylim: if ylim:
axes[c].set_ylim(ylim) axes[c].set_ylim(ylim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False) plt.show(block=False)
def print_stats(waveform, sample_rate=None, src=None): def print_stats(waveform, sample_rate=None, src=None):
if src: if src:
print("-" * 10) print("-" * 10)
print("Source:", src) print("Source:", src)
print("-" * 10) print("-" * 10)
if sample_rate: if sample_rate:
print("Sample Rate:", sample_rate) print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape)) print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype) print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}") print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}") print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}") print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}") print(f" - Std Dev: {waveform.std().item():6.3f}")
print() print()
print(waveform) print(waveform)
print() print()
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    plt.show(block=False)
def play_audio(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    if num_channels == 1:
        display(Audio(waveform[0], rate=sample_rate))
    elif num_channels == 2:
        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")
def get_rir_sample(*, resample=None, processed=False): def get_rir_sample(*, resample=None, processed=False):
rir_raw, sample_rate = _get_sample(SAMPLE_RIR_PATH, resample=resample) rir_raw, sample_rate = _get_sample(SAMPLE_RIR_PATH, resample=resample)
if not processed: if not processed:
return rir_raw, sample_rate return rir_raw, sample_rate
rir = rir_raw[:, int(sample_rate*1.01):int(sample_rate*1.3)] rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
rir = rir / torch.norm(rir, p=2) rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1]) rir = torch.flip(rir, [1])
return rir, sample_rate return rir, sample_rate
def get_noise_sample(*, resample=None): def get_noise_sample(*, resample=None):
return _get_sample(SAMPLE_NOISE_PATH, resample=resample) return _get_sample(SAMPLE_NOISE_PATH, resample=resample)
###################################################################### ######################################################################
...@@ -208,20 +219,21 @@ waveform1, sample_rate1 = get_sample(resample=16000) ...@@ -208,20 +219,21 @@ waveform1, sample_rate1 = get_sample(resample=16000)
# Define effects # Define effects
effects = [ effects = [
["lowpass", "-1", "300"], # apply single-pole lowpass filter ["lowpass", "-1", "300"], # apply single-pole lowpass filter
["speed", "0.8"], # reduce the speed ["speed", "0.8"], # reduce the speed
# This only changes sample rate, so it is necessary to # This only changes sample rate, so it is necessary to
# add `rate` effect with original sample rate after this. # add `rate` effect with original sample rate after this.
["rate", f"{sample_rate1}"], ["rate", f"{sample_rate1}"],
["reverb", "-w"], # Reverbration gives some dramatic feeling ["reverb", "-w"], # Reverbration gives some dramatic feeling
] ]
# Apply effects # Apply effects
waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(
    waveform1, sample_rate1, effects
)
plot_waveform(waveform1, sample_rate1, title="Original", xlim=(-.1, 3.2)) plot_waveform(waveform1, sample_rate1, title="Original", xlim=(-0.1, 3.2))
plot_waveform(waveform2, sample_rate2, title="Effects Applied", xlim=(-.1, 3.2)) plot_waveform(waveform2, sample_rate2, title="Effects Applied", xlim=(-0.1, 3.2))
print_stats(waveform1, sample_rate=sample_rate1, src="Original") print_stats(waveform1, sample_rate=sample_rate1, src="Original")
print_stats(waveform2, sample_rate=sample_rate2, src="Effects Applied") print_stats(waveform2, sample_rate=sample_rate2, src="Effects Applied")
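
######################################################################
# As a rough sanity check (a small sketch reusing the tensors defined
# above): the ``speed 0.8`` effect slows the audio down, so the
# processed clip should be roughly ``1 / 0.8`` times longer than the
# original once both durations are expressed in seconds.

original_duration = waveform1.shape[1] / sample_rate1
processed_duration = waveform2.shape[1] / sample_rate2
print(f"Original:  {original_duration:.2f} sec")
print(
    f"Processed: {processed_duration:.2f} sec "
    f"(expected roughly {original_duration / 0.8:.2f} sec)"
)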
...@@ -268,7 +280,7 @@ play_audio(rir_raw, sample_rate) ...@@ -268,7 +280,7 @@ play_audio(rir_raw, sample_rate)
# the signal power, then flip along the time axis. # the signal power, then flip along the time axis.
# #
rir = rir_raw[:, int(sample_rate*1.01):int(sample_rate*1.3)] rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
rir = rir / torch.norm(rir, p=2) rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1]) rir = torch.flip(rir, [1])
...@@ -281,7 +293,7 @@ plot_waveform(rir, sample_rate, title="Room Impulse Response", ylim=None) ...@@ -281,7 +293,7 @@ plot_waveform(rir, sample_rate, title="Room Impulse Response", ylim=None)
speech, _ = get_speech_sample(resample=sample_rate) speech, _ = get_speech_sample(resample=sample_rate)
speech_ = torch.nn.functional.pad(speech, (rir.shape[1]-1, 0)) speech_ = torch.nn.functional.pad(speech, (rir.shape[1] - 1, 0))
augmented = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0] augmented = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
plot_waveform(speech, sample_rate, title="Original", ylim=None) plot_waveform(speech, sample_rate, title="Original", ylim=None)
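
######################################################################
# A small sketch, using the tensors from the cell above: ``conv1d``
# computes cross-correlation, so flipping the RIR along the time axis
# turns the operation into a true convolution, and left-padding the
# speech by ``rir.shape[1] - 1`` keeps the output the same length as
# the input.

num_frames = speech.shape[1]
num_taps = rir.shape[1]
# padded length is num_frames + num_taps - 1; "valid" convolution then
# returns (num_frames + num_taps - 1) - num_taps + 1 == num_frames frames
assert augmented.shape[1] == num_frames
print(f"speech: {num_frames} frames, rir: {num_taps} taps")
print(f"augmented: {augmented.shape[1]} frames")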
...@@ -312,7 +324,7 @@ play_audio(augmented, sample_rate) ...@@ -312,7 +324,7 @@ play_audio(augmented, sample_rate)
sample_rate = 8000 sample_rate = 8000
speech, _ = get_speech_sample(resample=sample_rate) speech, _ = get_speech_sample(resample=sample_rate)
noise, _ = get_noise_sample(resample=sample_rate) noise, _ = get_noise_sample(resample=sample_rate)
noise = noise[:, :speech.shape[1]] noise = noise[:, : speech.shape[1]]
plot_waveform(noise, sample_rate, title="Background noise") plot_waveform(noise, sample_rate, title="Background noise")
plot_specgram(noise, sample_rate, title="Background noise") plot_specgram(noise, sample_rate, title="Background noise")
...@@ -322,13 +334,13 @@ speech_power = speech.norm(p=2) ...@@ -322,13 +334,13 @@ speech_power = speech.norm(p=2)
noise_power = noise.norm(p=2) noise_power = noise.norm(p=2)
for snr_db in [20, 10, 3]: for snr_db in [20, 10, 3]:
snr = math.exp(snr_db / 10) snr = math.exp(snr_db / 10)
scale = snr * noise_power / speech_power scale = snr * noise_power / speech_power
noisy_speech = (scale * speech + noise) / 2 noisy_speech = (scale * speech + noise) / 2
plot_waveform(noisy_speech, sample_rate, title=f"SNR: {snr_db} [dB]") plot_waveform(noisy_speech, sample_rate, title=f"SNR: {snr_db} [dB]")
plot_specgram(noisy_speech, sample_rate, title=f"SNR: {snr_db} [dB]") plot_specgram(noisy_speech, sample_rate, title=f"SNR: {snr_db} [dB]")
play_audio(noisy_speech, sample_rate) play_audio(noisy_speech, sample_rate)
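
######################################################################
# A minimal sketch to double-check the scaling above: after the speech
# is multiplied by ``scale``, the ratio of its L2 norm to the noise's
# L2 norm equals ``snr``, and the common ``/ 2`` factor leaves that
# ratio unchanged.

snr_db = 3
snr = math.exp(snr_db / 10)
scale = snr * noise_power / speech_power
print("target ratio:  ", snr)
print("achieved ratio:", (scale * speech.norm(p=2) / noise.norm(p=2)).item())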
###################################################################### ######################################################################
# Applying codec to Tensor object # Applying codec to Tensor object
...@@ -346,15 +358,15 @@ plot_specgram(waveform, sample_rate, title="Original") ...@@ -346,15 +358,15 @@ plot_specgram(waveform, sample_rate, title="Original")
play_audio(waveform, sample_rate) play_audio(waveform, sample_rate)
configs = [ configs = [
({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"), ({"format": "wav", "encoding": "ULAW", "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "gsm"}, "GSM-FR"), ({"format": "gsm"}, "GSM-FR"),
({"format": "mp3", "compression": -9}, "MP3"), ({"format": "mp3", "compression": -9}, "MP3"),
({"format": "vorbis", "compression": -1}, "Vorbis"), ({"format": "vorbis", "compression": -1}, "Vorbis"),
] ]
for param, title in configs: for param, title in configs:
augmented = F.apply_codec(waveform, sample_rate, **param) augmented = F.apply_codec(waveform, sample_rate, **param)
plot_specgram(augmented, sample_rate, title=title) plot_specgram(augmented, sample_rate, title=title)
play_audio(augmented, sample_rate) play_audio(augmented, sample_rate)
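
######################################################################
# A related sketch: the 8-bit mu-law configuration above can also be
# emulated directly with ``mu_law_encoding`` / ``mu_law_decoding`` from
# ``torchaudio.functional`` (imported as ``F`` in this tutorial),
# without a file-format round trip.

quantized = F.mu_law_encoding(waveform, quantization_channels=256)
restored = F.mu_law_decoding(quantized, quantization_channels=256)
plot_specgram(restored, sample_rate, title="mu-law encode/decode")
play_audio(restored, sample_rate)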
###################################################################### ######################################################################
# Simulating a phone recoding # Simulating a phone recoding
...@@ -373,7 +385,7 @@ play_audio(speech, sample_rate) ...@@ -373,7 +385,7 @@ play_audio(speech, sample_rate)
# Apply RIR # Apply RIR
rir, _ = get_rir_sample(resample=sample_rate, processed=True) rir, _ = get_rir_sample(resample=sample_rate, processed=True)
speech_ = torch.nn.functional.pad(speech, (rir.shape[1]-1, 0)) speech_ = torch.nn.functional.pad(speech, (rir.shape[1] - 1, 0))
speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0] speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
plot_specgram(speech, sample_rate, title="RIR Applied") plot_specgram(speech, sample_rate, title="RIR Applied")
...@@ -384,7 +396,7 @@ play_audio(speech, sample_rate) ...@@ -384,7 +396,7 @@ play_audio(speech, sample_rate)
# the noise contains the acoustic feature of the environment. Therefore, we add # the noise contains the acoustic feature of the environment. Therefore, we add
# the noise after RIR application. # the noise after RIR application.
noise, _ = get_noise_sample(resample=sample_rate) noise, _ = get_noise_sample(resample=sample_rate)
noise = noise[:, :speech.shape[1]] noise = noise[:, : speech.shape[1]]
snr_db = 8 snr_db = 8
scale = math.exp(snr_db / 10) * noise.norm(p=2) / speech.norm(p=2) scale = math.exp(snr_db / 10) * noise.norm(p=2) / speech.norm(p=2)
...@@ -395,13 +407,20 @@ play_audio(speech, sample_rate) ...@@ -395,13 +407,20 @@ play_audio(speech, sample_rate)
# Apply filtering and change sample rate # Apply filtering and change sample rate
speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
    speech,
    sample_rate,
    effects=[
        ["lowpass", "4000"],
        [
            "compand",
            "0.02,0.05",
            "-60,-60,-30,-10,-20,-8,-5,-8,-2,-8",
            "-8",
            "-7",
            "0.05",
        ],
        ["rate", "8000"],
    ],
)
plot_specgram(speech, sample_rate, title="Filtered") plot_specgram(speech, sample_rate, title="Filtered")
......
...@@ -23,14 +23,14 @@ print(torchaudio.__version__) ...@@ -23,14 +23,14 @@ print(torchaudio.__version__)
# -------------------------------------------------------- # --------------------------------------------------------
# #
#@title Prepare data and utility functions. {display-mode: "form"} # @title Prepare data and utility functions. {display-mode: "form"}
#@markdown # @markdown
#@markdown You do not need to look into this cell. # @markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go. # @markdown Just execute once and you are good to go.
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
# Preparation of data and helper functions. # Preparation of data and helper functions.
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
import multiprocessing import multiprocessing
import os import os
...@@ -42,52 +42,57 @@ _SAMPLE_DIR = "_assets" ...@@ -42,52 +42,57 @@ _SAMPLE_DIR = "_assets"
YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no") YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True) os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def _download_yesno(): def _download_yesno():
if os.path.exists(os.path.join(YESNO_DATASET_PATH, "waves_yesno.tar.gz")): if os.path.exists(os.path.join(YESNO_DATASET_PATH, "waves_yesno.tar.gz")):
return return
torchaudio.datasets.YESNO(root=YESNO_DATASET_PATH, download=True) torchaudio.datasets.YESNO(root=YESNO_DATASET_PATH, download=True)
YESNO_DOWNLOAD_PROCESS = multiprocessing.Process(target=_download_yesno) YESNO_DOWNLOAD_PROCESS = multiprocessing.Process(target=_download_yesno)
YESNO_DOWNLOAD_PROCESS.start() YESNO_DOWNLOAD_PROCESS.start()
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    plt.show(block=False)
def play_audio(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    if num_channels == 1:
        display(Audio(waveform[0], rate=sample_rate))
    elif num_channels == 2:
        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")
###################################################################### ######################################################################
# Here, we show how to use the ``YESNO`` dataset. # Here, we show how to use the ``YESNO`` dataset.
# #
YESNO_DOWNLOAD_PROCESS.join() YESNO_DOWNLOAD_PROCESS.join()
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True) dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
for i in [1, 3, 5]: for i in [1, 3, 5]:
waveform, sample_rate, label = dataset[i] waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}") plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
play_audio(waveform, sample_rate) play_audio(waveform, sample_rate)
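
######################################################################
# A possible follow-up, sketched here under the assumption that
# ``torch`` is imported: because the YESNO recordings have different
# lengths, batching them with ``torch.utils.data.DataLoader`` needs a
# collate function that pads the waveforms to a common length.


def collate_fn(batch):
    waveforms, sample_rates, labels = zip(*batch)
    max_len = max(w.shape[1] for w in waveforms)
    padded = torch.stack(
        [torch.nn.functional.pad(w, (0, max_len - w.shape[1])) for w in waveforms]
    )
    return padded, sample_rates, labels


loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
batch, rates, labels = next(iter(loader))
print(batch.shape)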
...@@ -20,16 +20,17 @@ print(torchaudio.__version__) ...@@ -20,16 +20,17 @@ print(torchaudio.__version__)
# -------------------------------------------------------- # --------------------------------------------------------
# #
#@title Prepare data and utility functions. {display-mode: "form"} # @title Prepare data and utility functions. {display-mode: "form"}
#@markdown # @markdown
#@markdown You do not need to look into this cell. # @markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go. # @markdown Just execute once and you are good to go.
#@markdown # @markdown
#@markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commos BY 4.0. # @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# Preparation of data and helper functions. # Preparation of data and helper functions.
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
import os import os
import requests import requests
...@@ -40,62 +41,69 @@ import matplotlib.pyplot as plt ...@@ -40,62 +41,69 @@ import matplotlib.pyplot as plt
_SAMPLE_DIR = "_assets" _SAMPLE_DIR = "_assets"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav") SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True) os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data(): def _fetch_data():
uri = [ uri = [
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH), (SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
] ]
for url, path in uri: for url, path in uri:
with open(path, 'wb') as file_: with open(path, "wb") as file_:
file_.write(requests.get(url).content) file_.write(requests.get(url).content)
_fetch_data() _fetch_data()
def _get_sample(path, resample=None):
    effects = [["remix", "1"]]
    if resample:
        effects.extend(
            [
                ["lowpass", f"{resample // 2}"],
                ["rate", f"{resample}"],
            ]
        )
    return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_speech_sample(*, resample=None): def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample) return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def get_spectrogram( def get_spectrogram(
n_fft = 400, n_fft=400,
win_len = None, win_len=None,
hop_len = None, hop_len=None,
power = 2.0, power=2.0,
): ):
waveform, _ = get_speech_sample() waveform, _ = get_speech_sample()
spectrogram = T.Spectrogram( spectrogram = T.Spectrogram(
n_fft=n_fft, n_fft=n_fft,
win_length=win_len, win_length=win_len,
hop_length=hop_len, hop_length=hop_len,
center=True, center=True,
pad_mode="reflect", pad_mode="reflect",
power=power, power=power,
) )
return spectrogram(waveform) return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Spectrogram (db)")
    axs.set_ylabel(ylabel)
    axs.set_xlabel("frame")
    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
    if xmax:
        axs.set_xlim((0, xmax))
    fig.colorbar(im, ax=axs)
    plt.show(block=False)
###################################################################### ######################################################################
# SpecAugment # SpecAugment
...@@ -111,18 +119,23 @@ def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=No ...@@ -111,18 +119,23 @@ def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=No
# ~~~~~~~~~~~ # ~~~~~~~~~~~
# #
spec = get_spectrogram(power=None) spec = get_spectrogram(power=None)
stretch = T.TimeStretch() stretch = T.TimeStretch()
rate = 1.2 rate = 1.2
spec_ = stretch(spec, rate) spec_ = stretch(spec, rate)
plot_spectrogram(
    torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect='equal', xmax=304) plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9 rate = 0.9
spec_ = stretch(spec, rate) spec_ = stretch(spec, rate)
plot_spectrogram(
    torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
)
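
######################################################################
# A rough check, sketched with the variables above: ``TimeStretch``
# changes the number of spectrogram frames by approximately a factor
# of ``1 / rate`` while leaving the frequency bins untouched.

print("original frames: ", spec.shape[-1])
print("stretched frames:", spec_.shape[-1], f"(rate={rate})")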
###################################################################### ######################################################################
# TimeMasking # TimeMasking
......
...@@ -38,16 +38,17 @@ print(torchaudio.__version__) ...@@ -38,16 +38,17 @@ print(torchaudio.__version__)
# -------------------------------------------------------- # --------------------------------------------------------
# #
#@title Prepare data and utility functions. {display-mode: "form"} # @title Prepare data and utility functions. {display-mode: "form"}
#@markdown # @markdown
#@markdown You do not need to look into this cell. # @markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go. # @markdown Just execute once and you are good to go.
#@markdown # @markdown
#@markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commos BY 4.0. # @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# Preparation of data and helper functions. # Preparation of data and helper functions.
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
import os import os
import requests import requests
...@@ -59,143 +60,154 @@ from IPython.display import Audio, display ...@@ -59,143 +60,154 @@ from IPython.display import Audio, display
_SAMPLE_DIR = "_assets" _SAMPLE_DIR = "_assets"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav") SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True) os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data(): def _fetch_data():
uri = [ uri = [
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH), (SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
] ]
for url, path in uri: for url, path in uri:
with open(path, 'wb') as file_: with open(path, "wb") as file_:
file_.write(requests.get(url).content) file_.write(requests.get(url).content)
_fetch_data() _fetch_data()
def _get_sample(path, resample=None):
    effects = [["remix", "1"]]
    if resample:
        effects.extend(
            [
                ["lowpass", f"{resample // 2}"],
                ["rate", f"{resample}"],
            ]
        )
    return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_speech_sample(*, resample=None): def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample) return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def print_stats(waveform, sample_rate=None, src=None): def print_stats(waveform, sample_rate=None, src=None):
if src: if src:
print("-" * 10) print("-" * 10)
print("Source:", src) print("Source:", src)
print("-" * 10) print("-" * 10)
if sample_rate: if sample_rate:
print("Sample Rate:", sample_rate) print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape)) print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype) print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}") print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}") print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}") print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}") print(f" - Std Dev: {waveform.std().item():6.3f}")
print() print()
print(waveform) print(waveform)
print() print()
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Spectrogram (db)")
    axs.set_ylabel(ylabel)
    axs.set_xlabel("frame")
    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
    if xmax:
        axs.set_xlim((0, xmax))
    fig.colorbar(im, ax=axs)
    plt.show(block=False)
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None): def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy() waveform = waveform.numpy()
num_channels, num_frames = waveform.shape num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1) figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1: if num_channels == 1:
axes = [axes] axes = [axes]
for c in range(num_channels): for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1) axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True) axes[c].grid(True)
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}') axes[c].set_ylabel(f"Channel {c+1}")
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
if ylim: if ylim:
axes[c].set_ylim(ylim) axes[c].set_ylim(ylim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False) plt.show(block=False)
def play_audio(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    if num_channels == 1:
        display(Audio(waveform[0], rate=sample_rate))
    elif num_channels == 2:
        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")
def plot_mel_fbank(fbank, title=None): def plot_mel_fbank(fbank, title=None):
fig, axs = plt.subplots(1, 1) fig, axs = plt.subplots(1, 1)
axs.set_title(title or 'Filter bank') axs.set_title(title or "Filter bank")
axs.imshow(fbank, aspect='auto') axs.imshow(fbank, aspect="auto")
axs.set_ylabel('frequency bin') axs.set_ylabel("frequency bin")
axs.set_xlabel('mel bin') axs.set_xlabel("mel bin")
plt.show(block=False) plt.show(block=False)
def plot_pitch(waveform, sample_rate, pitch): def plot_pitch(waveform, sample_rate, pitch):
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
axis.set_title("Pitch Feature") axis.set_title("Pitch Feature")
axis.grid(True) axis.grid(True)
end_time = waveform.shape[1] / sample_rate end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1]) time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color='gray', alpha=0.3) axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
axis2 = axis.twinx() axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, pitch.shape[1]) time_axis = torch.linspace(0, end_time, pitch.shape[1])
    axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")

    axis2.legend(loc=0)
    plt.show(block=False)
def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc): def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
axis.set_title("Kaldi Pitch Feature") axis.set_title("Kaldi Pitch Feature")
axis.grid(True) axis.grid(True)
    end_time = waveform.shape[1] / sample_rate
    time_axis = torch.linspace(0, end_time, waveform.shape[1])
    axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)

    time_axis = torch.linspace(0, end_time, pitch.shape[1])
    ln1 = axis.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
    axis.set_ylim((-1.3, 1.3))

    axis2 = axis.twinx()
    time_axis = torch.linspace(0, end_time, nfcc.shape[1])
    ln2 = axis2.plot(
        time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--"
    )

    lns = ln1 + ln2
    labels = [l.get_label() for l in lns]
    axis.legend(lns, labels, loc=0)
    plt.show(block=False)
###################################################################### ######################################################################
# Spectrogram # Spectrogram
...@@ -206,7 +218,6 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc): ...@@ -206,7 +218,6 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
# #
waveform, sample_rate = get_speech_sample() waveform, sample_rate = get_speech_sample()
n_fft = 1024 n_fft = 1024
...@@ -226,7 +237,7 @@ spectrogram = T.Spectrogram( ...@@ -226,7 +237,7 @@ spectrogram = T.Spectrogram(
spec = spectrogram(waveform) spec = spectrogram(waveform)
print_stats(spec) print_stats(spec)
plot_spectrogram(spec[0], title='torchaudio') plot_spectrogram(spec[0], title="torchaudio")
###################################################################### ######################################################################
# GriffinLim # GriffinLim
...@@ -280,10 +291,10 @@ sample_rate = 6000 ...@@ -280,10 +291,10 @@ sample_rate = 6000
mel_filters = F.melscale_fbanks( mel_filters = F.melscale_fbanks(
int(n_fft // 2 + 1), int(n_fft // 2 + 1),
n_mels=n_mels, n_mels=n_mels,
f_min=0., f_min=0.0,
f_max=sample_rate/2., f_max=sample_rate / 2.0,
sample_rate=sample_rate, sample_rate=sample_rate,
norm='slaney' norm="slaney",
) )
plot_mel_fbank(mel_filters, "Mel Filter Bank - torchaudio") plot_mel_fbank(mel_filters, "Mel Filter Bank - torchaudio")
...@@ -300,16 +311,16 @@ mel_filters_librosa = librosa.filters.mel( ...@@ -300,16 +311,16 @@ mel_filters_librosa = librosa.filters.mel(
sample_rate, sample_rate,
n_fft, n_fft,
n_mels=n_mels, n_mels=n_mels,
fmin=0., fmin=0.0,
fmax=sample_rate/2., fmax=sample_rate / 2.0,
norm='slaney', norm="slaney",
htk=True, htk=True,
).T ).T
plot_mel_fbank(mel_filters_librosa, "Mel Filter Bank - librosa") plot_mel_fbank(mel_filters_librosa, "Mel Filter Bank - librosa")
mse = torch.square(mel_filters - mel_filters_librosa).mean().item() mse = torch.square(mel_filters - mel_filters_librosa).mean().item()
print('Mean Square Difference: ', mse) print("Mean Square Difference: ", mse)
###################################################################### ######################################################################
# MelSpectrogram # MelSpectrogram
...@@ -336,15 +347,14 @@ mel_spectrogram = T.MelSpectrogram( ...@@ -336,15 +347,14 @@ mel_spectrogram = T.MelSpectrogram(
center=True, center=True,
pad_mode="reflect", pad_mode="reflect",
power=2.0, power=2.0,
norm='slaney', norm="slaney",
onesided=True, onesided=True,
n_mels=n_mels, n_mels=n_mels,
mel_scale="htk", mel_scale="htk",
) )
melspec = mel_spectrogram(waveform) melspec = mel_spectrogram(waveform)
plot_spectrogram(melspec[0], title="MelSpectrogram - torchaudio", ylabel="mel freq")
###################################################################### ######################################################################
# Comparison against librosa # Comparison against librosa
...@@ -365,14 +375,13 @@ melspec_librosa = librosa.feature.melspectrogram( ...@@ -365,14 +375,13 @@ melspec_librosa = librosa.feature.melspectrogram(
pad_mode="reflect", pad_mode="reflect",
power=2.0, power=2.0,
n_mels=n_mels, n_mels=n_mels,
norm='slaney', norm="slaney",
htk=True, htk=True,
) )
plot_spectrogram(melspec_librosa, title="MelSpectrogram - librosa", ylabel="mel freq")
mse = torch.square(melspec - melspec_librosa).mean().item() mse = torch.square(melspec - melspec_librosa).mean().item()
print('Mean Square Difference: ', mse) print("Mean Square Difference: ", mse)
###################################################################### ######################################################################
# MFCC # MFCC
...@@ -391,11 +400,11 @@ mfcc_transform = T.MFCC( ...@@ -391,11 +400,11 @@ mfcc_transform = T.MFCC(
sample_rate=sample_rate, sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
melkwargs={ melkwargs={
'n_fft': n_fft, "n_fft": n_fft,
'n_mels': n_mels, "n_mels": n_mels,
'hop_length': hop_length, "hop_length": hop_length,
'mel_scale': 'htk', "mel_scale": "htk",
} },
) )
mfcc = mfcc_transform(waveform) mfcc = mfcc_transform(waveform)
...@@ -409,18 +418,27 @@ plot_spectrogram(mfcc[0]) ...@@ -409,18 +418,27 @@ plot_spectrogram(mfcc[0])
melspec = librosa.feature.melspectrogram(
    y=waveform.numpy()[0],
    sr=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    n_mels=n_mels,
    htk=True,
    norm=None,
)
mfcc_librosa = librosa.feature.mfcc(
    S=librosa.core.spectrum.power_to_db(melspec),
    n_mfcc=n_mfcc,
    dct_type=2,
    norm="ortho",
)
plot_spectrogram(mfcc_librosa) plot_spectrogram(mfcc_librosa)
mse = torch.square(mfcc - mfcc_librosa).mean().item() mse = torch.square(mfcc - mfcc_librosa).mean().item()
print('Mean Square Difference: ', mse) print("Mean Square Difference: ", mse)
###################################################################### ######################################################################
# Pitch # Pitch
......
...@@ -21,12 +21,13 @@ print(torchaudio.__version__) ...@@ -21,12 +21,13 @@ print(torchaudio.__version__)
# -------------------------------------------------------- # --------------------------------------------------------
# #
#@title Prepare data and utility functions. {display-mode: "form"} # @title Prepare data and utility functions. {display-mode: "form"}
#@markdown # @markdown
#@markdown You do not need to look into this cell. # @markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go. # @markdown Just execute once and you are good to go.
#@markdown # @markdown
#@markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/), which is licensed under Creative Commos BY 4.0. # @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
import io import io
...@@ -51,7 +52,7 @@ SAMPLE_MP3_PATH = os.path.join(_SAMPLE_DIR, "steam.mp3") ...@@ -51,7 +52,7 @@ SAMPLE_MP3_PATH = os.path.join(_SAMPLE_DIR, "steam.mp3")
SAMPLE_GSM_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/steam-train-whistle-daniel_simon.gsm" SAMPLE_GSM_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/steam-train-whistle-daniel_simon.gsm"
SAMPLE_GSM_PATH = os.path.join(_SAMPLE_DIR, "steam.gsm") SAMPLE_GSM_PATH = os.path.join(_SAMPLE_DIR, "steam.gsm")
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav") SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
SAMPLE_TAR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit.tar.gz" SAMPLE_TAR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit.tar.gz"
...@@ -61,117 +62,127 @@ SAMPLE_TAR_ITEM = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp ...@@ -61,117 +62,127 @@ SAMPLE_TAR_ITEM = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp
S3_BUCKET = "pytorch-tutorial-assets" S3_BUCKET = "pytorch-tutorial-assets"
S3_KEY = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" S3_KEY = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
def _fetch_data(): def _fetch_data():
os.makedirs(_SAMPLE_DIR, exist_ok=True) os.makedirs(_SAMPLE_DIR, exist_ok=True)
uri = [ uri = [
(SAMPLE_WAV_URL, SAMPLE_WAV_PATH), (SAMPLE_WAV_URL, SAMPLE_WAV_PATH),
(SAMPLE_MP3_URL, SAMPLE_MP3_PATH), (SAMPLE_MP3_URL, SAMPLE_MP3_PATH),
(SAMPLE_GSM_URL, SAMPLE_GSM_PATH), (SAMPLE_GSM_URL, SAMPLE_GSM_PATH),
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH), (SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
(SAMPLE_TAR_URL, SAMPLE_TAR_PATH), (SAMPLE_TAR_URL, SAMPLE_TAR_PATH),
] ]
for url, path in uri: for url, path in uri:
with open(path, 'wb') as file_: with open(path, "wb") as file_:
file_.write(requests.get(url).content) file_.write(requests.get(url).content)
_fetch_data() _fetch_data()
def print_stats(waveform, sample_rate=None, src=None): def print_stats(waveform, sample_rate=None, src=None):
if src: if src:
print("-" * 10) print("-" * 10)
print("Source:", src) print("Source:", src)
print("-" * 10) print("-" * 10)
if sample_rate: if sample_rate:
print("Sample Rate:", sample_rate) print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape)) print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype) print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}") print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}") print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}") print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}") print(f" - Std Dev: {waveform.std().item():6.3f}")
print() print()
print(waveform) print(waveform)
print() print()
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None): def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy() waveform = waveform.numpy()
num_channels, num_frames = waveform.shape num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1) figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1: if num_channels == 1:
axes = [axes] axes = [axes]
for c in range(num_channels): for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1) axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True) axes[c].grid(True)
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f'Channel {c+1}') axes[c].set_ylabel(f"Channel {c+1}")
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
if ylim: if ylim:
axes[c].set_ylim(ylim) axes[c].set_ylim(ylim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False) plt.show(block=False)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    plt.show(block=False)
def play_audio(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    if num_channels == 1:
        display(Audio(waveform[0], rate=sample_rate))
    elif num_channels == 2:
        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")
def _get_sample(path, resample=None):
    effects = [["remix", "1"]]
    if resample:
        effects.extend(
            [
                ["lowpass", f"{resample // 2}"],
                ["rate", f"{resample}"],
            ]
        )
    return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def get_sample(*, resample=None): def get_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_PATH, resample=resample) return _get_sample(SAMPLE_WAV_PATH, resample=resample)
def inspect_file(path): def inspect_file(path):
print("-" * 10) print("-" * 10)
print("Source:", path) print("Source:", path)
print("-" * 10) print("-" * 10)
print(f" - File size: {os.path.getsize(path)} bytes") print(f" - File size: {os.path.getsize(path)} bytes")
print(f" - {torchaudio.info(path)}") print(f" - {torchaudio.info(path)}")
###################################################################### ######################################################################
# Quering audio metadata # Querying audio metadata
# ---------------------- # -----------------------
# #
# Function ``torchaudio.info`` fetches audio metadata. You can provide # Function ``torchaudio.info`` fetches audio metadata. You can provide
# a path-like object or file-like object. # a path-like object or file-like object.
# #
metadata = torchaudio.info(SAMPLE_WAV_PATH) metadata = torchaudio.info(SAMPLE_WAV_PATH)
print(metadata) print(metadata)
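
######################################################################
# The returned ``AudioMetaData`` object exposes the individual fields
# as attributes, so they can also be used programmatically (a short
# sketch using the ``metadata`` object from the call above):

print("sample_rate:", metadata.sample_rate)
print("num_frames:", metadata.num_frames)
print("num_channels:", metadata.num_channels)
print("bits_per_sample:", metadata.bits_per_sample)
print("encoding:", metadata.encoding)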
...@@ -231,7 +242,7 @@ print(metadata) ...@@ -231,7 +242,7 @@ print(metadata)
print("Source:", SAMPLE_WAV_URL) print("Source:", SAMPLE_WAV_URL)
with requests.get(SAMPLE_WAV_URL, stream=True) as response: with requests.get(SAMPLE_WAV_URL, stream=True) as response:
metadata = torchaudio.info(response.raw) metadata = torchaudio.info(response.raw)
print(metadata) print(metadata)
###################################################################### ######################################################################
...@@ -248,9 +259,9 @@ print(metadata) ...@@ -248,9 +259,9 @@ print(metadata)
print("Source:", SAMPLE_MP3_URL) print("Source:", SAMPLE_MP3_URL)
with requests.get(SAMPLE_MP3_URL, stream=True) as response: with requests.get(SAMPLE_MP3_URL, stream=True) as response:
metadata = torchaudio.info(response.raw, format="mp3") metadata = torchaudio.info(response.raw, format="mp3")
print(f"Fetched {response.raw.tell()} bytes.") print(f"Fetched {response.raw.tell()} bytes.")
print(metadata) print(metadata)
###################################################################### ######################################################################
...@@ -291,19 +302,19 @@ play_audio(waveform, sample_rate) ...@@ -291,19 +302,19 @@ play_audio(waveform, sample_rate)
# Load audio data as HTTP request # Load audio data as HTTP request
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response: with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform, sample_rate = torchaudio.load(response.raw) waveform, sample_rate = torchaudio.load(response.raw)
plot_specgram(waveform, sample_rate, title="HTTP datasource") plot_specgram(waveform, sample_rate, title="HTTP datasource")
# Load audio from tar file # Load audio from tar file
with tarfile.open(SAMPLE_TAR_PATH, mode='r') as tarfile_: with tarfile.open(SAMPLE_TAR_PATH, mode="r") as tarfile_:
fileobj = tarfile_.extractfile(SAMPLE_TAR_ITEM) fileobj = tarfile_.extractfile(SAMPLE_TAR_ITEM)
waveform, sample_rate = torchaudio.load(fileobj) waveform, sample_rate = torchaudio.load(fileobj)
plot_specgram(waveform, sample_rate, title="TAR file") plot_specgram(waveform, sample_rate, title="TAR file")
# Load audio from S3 # Load audio from S3
client = boto3.client('s3', config=Config(signature_version=UNSIGNED)) client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
response = client.get_object(Bucket=S3_BUCKET, Key=S3_KEY) response = client.get_object(Bucket=S3_BUCKET, Key=S3_KEY)
waveform, sample_rate = torchaudio.load(response['Body']) waveform, sample_rate = torchaudio.load(response["Body"])
plot_specgram(waveform, sample_rate, title="From S3") plot_specgram(waveform, sample_rate, title="From S3")
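
######################################################################
# The same file-like interface also covers plain in-memory buffers,
# e.g. data that has already been downloaded. A sketch, assuming ``io``
# and ``requests`` are imported as elsewhere in this tutorial:

data = requests.get(SAMPLE_WAV_SPEECH_URL).content
waveform, sample_rate = torchaudio.load(io.BytesIO(data))
plot_specgram(waveform, sample_rate, title="In-memory bytes")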
...@@ -336,15 +347,16 @@ frame_offset, num_frames = 16000, 16000 # Fetch and decode the 1 - 2 seconds ...@@ -336,15 +347,16 @@ frame_offset, num_frames = 16000, 16000 # Fetch and decode the 1 - 2 seconds
print("Fetching all the data...") print("Fetching all the data...")
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response: with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform1, sample_rate1 = torchaudio.load(response.raw) waveform1, sample_rate1 = torchaudio.load(response.raw)
waveform1 = waveform1[:, frame_offset:frame_offset+num_frames] waveform1 = waveform1[:, frame_offset: frame_offset + num_frames]
print(f" - Fetched {response.raw.tell()} bytes") print(f" - Fetched {response.raw.tell()} bytes")
print("Fetching until the requested frames are available...") print("Fetching until the requested frames are available...")
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response: with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
    waveform2, sample_rate2 = torchaudio.load(
        response.raw, frame_offset=frame_offset, num_frames=num_frames
    )
    print(f" - Fetched {response.raw.tell()} bytes")
print("Checking the resulting waveform ... ", end="") print("Checking the resulting waveform ... ", end="")
assert (waveform1 == waveform2).all() assert (waveform1 == waveform2).all()
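
######################################################################
# A small convenience sketch: ``frame_offset`` and ``num_frames`` are
# expressed in frames, so time-based slicing can be derived from the
# sample rate reported by ``torchaudio.info``.

info = torchaudio.info(SAMPLE_WAV_SPEECH_PATH)
start_sec, duration_sec = 1.0, 1.0
waveform3, _ = torchaudio.load(
    SAMPLE_WAV_SPEECH_PATH,
    frame_offset=int(start_sec * info.sample_rate),
    num_frames=int(duration_sec * info.sample_rate),
)
print(waveform3.shape)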
...@@ -389,9 +401,7 @@ inspect_file(path) ...@@ -389,9 +401,7 @@ inspect_file(path)
# Save as 16-bit signed integer Linear PCM # Save as 16-bit signed integer Linear PCM
# The resulting file occupies half the storage but loses precision # The resulting file occupies half the storage but loses precision
path = f"{_SAMPLE_DIR}/save_example_PCM_S16.wav" path = f"{_SAMPLE_DIR}/save_example_PCM_S16.wav"
torchaudio.save(path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
inspect_file(path) inspect_file(path)
...@@ -402,19 +412,19 @@ inspect_file(path) ...@@ -402,19 +412,19 @@ inspect_file(path)
waveform, sample_rate = get_sample(resample=8000) waveform, sample_rate = get_sample(resample=8000)
formats = [ formats = [
"mp3", "mp3",
"flac", "flac",
"vorbis", "vorbis",
"sph", "sph",
"amb", "amb",
"amr-nb", "amr-nb",
"gsm", "gsm",
] ]
for format in formats: for format in formats:
path = f"{_SAMPLE_DIR}/save_example.{format}" path = f"{_SAMPLE_DIR}/save_example.{format}"
torchaudio.save(path, waveform, sample_rate, format=format) torchaudio.save(path, waveform, sample_rate, format=format)
inspect_file(path) inspect_file(path)
###################################################################### ######################################################################
...@@ -435,4 +445,3 @@ torchaudio.save(buffer_, waveform, sample_rate, format="wav") ...@@ -435,4 +445,3 @@ torchaudio.save(buffer_, waveform, sample_rate, format="wav")
buffer_.seek(0) buffer_.seek(0)
print(buffer_.read(16)) print(buffer_.read(16))
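
######################################################################
# The buffer written above can be read back the same way a file would
# be; a short sketch reusing ``buffer_`` from the previous cell:

buffer_.seek(0)
reloaded, reloaded_sample_rate = torchaudio.load(buffer_)
print("round-trip matches:", torch.equal(waveform, reloaded))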
...@@ -24,14 +24,14 @@ print(torchaudio.__version__) ...@@ -24,14 +24,14 @@ print(torchaudio.__version__)
# -------------------------------------------------------- # --------------------------------------------------------
# #
#@title Prepare data and utility functions. {display-mode: "form"} # @title Prepare data and utility functions. {display-mode: "form"}
#@markdown # @markdown
#@markdown You do not need to look into this cell. # @markdown You do not need to look into this cell.
#@markdown Just execute once and you are good to go. # @markdown Just execute once and you are good to go.
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
# Preparation of data and helper functions. # Preparation of data and helper functions.
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
import math import math
import time import time
...@@ -46,95 +46,108 @@ DEFAULT_OFFSET = 201 ...@@ -46,95 +46,108 @@ DEFAULT_OFFSET = 201
SWEEP_MAX_SAMPLE_RATE = 48000 SWEEP_MAX_SAMPLE_RATE = 48000
DEFAULT_LOWPASS_FILTER_WIDTH = 6 DEFAULT_LOWPASS_FILTER_WIDTH = 6
DEFAULT_ROLLOFF = 0.99 DEFAULT_ROLLOFF = 0.99
DEFAULT_RESAMPLING_METHOD = 'sinc_interpolation' DEFAULT_RESAMPLING_METHOD = "sinc_interpolation"
def _get_log_freq(sample_rate, max_sweep_rate, offset): def _get_log_freq(sample_rate, max_sweep_rate, offset):
"""Get freqs evenly spaced out in log-scale, between [0, max_sweep_rate // 2] """Get freqs evenly spaced out in log-scale, between [0, max_sweep_rate // 2]
offset is used to avoid negative infinity `log(offset + x)`. offset is used to avoid negative infinity `log(offset + x)`.
"""
start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
return (
torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
)
"""
half = sample_rate // 2
start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
def _get_inverse_log_freq(freq, sample_rate, offset): def _get_inverse_log_freq(freq, sample_rate, offset):
"""Find the time where the given frequency is given by _get_log_freq""" """Find the time where the given frequency is given by _get_log_freq"""
half = sample_rate // 2 half = sample_rate // 2
return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset)) return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset))
def _get_freq_ticks(sample_rate, offset, f_max): def _get_freq_ticks(sample_rate, offset, f_max):
# Given the original sample rate used for generating the sweep, # Given the original sample rate used for generating the sweep,
# find the x-axis value where the log-scale major frequency values fall in # find the x-axis value where the log-scale major frequency values fall in
time, freq = [], [] time, freq = [], []
for exp in range(2, 5): for exp in range(2, 5):
for v in range(1, 10): for v in range(1, 10):
f = v * 10 ** exp f = v * 10 ** exp
if f < sample_rate // 2: if f < sample_rate // 2:
t = _get_inverse_log_freq(f, sample_rate, offset) / sample_rate t = _get_inverse_log_freq(f, sample_rate, offset) / sample_rate
time.append(t) time.append(t)
freq.append(f) freq.append(f)
t_max = _get_inverse_log_freq(f_max, sample_rate, offset) / sample_rate t_max = _get_inverse_log_freq(f_max, sample_rate, offset) / sample_rate
time.append(t_max) time.append(t_max)
freq.append(f_max) freq.append(f_max)
return time, freq return time, freq
def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET): def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET):
max_sweep_rate = sample_rate max_sweep_rate = sample_rate
freq = _get_log_freq(sample_rate, max_sweep_rate, offset) freq = _get_log_freq(sample_rate, max_sweep_rate, offset)
delta = 2 * math.pi * freq / sample_rate delta = 2 * math.pi * freq / sample_rate
cummulative = torch.cumsum(delta, dim=0) cummulative = torch.cumsum(delta, dim=0)
signal = torch.sin(cummulative).unsqueeze(dim=0) signal = torch.sin(cummulative).unsqueeze(dim=0)
return signal return signal
def plot_sweep(waveform, sample_rate, title, max_sweep_rate=SWEEP_MAX_SAMPLE_RATE, offset=DEFAULT_OFFSET):
x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2] def plot_sweep(
y_ticks = [1000, 5000, 10000, 20000, sample_rate//2] waveform,
sample_rate,
time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2) title,
freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq] max_sweep_rate=SWEEP_MAX_SAMPLE_RATE,
freq_y = [f for f in freq if f >= 1000 and f in y_ticks and f <= sample_rate // 2] offset=DEFAULT_OFFSET,
):
figure, axis = plt.subplots(1, 1) x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
axis.specgram(waveform[0].numpy(), Fs=sample_rate) y_ticks = [1000, 5000, 10000, 20000, sample_rate // 2]
plt.xticks(time, freq_x)
plt.yticks(freq_y, freq_y) time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
axis.set_xlabel('Original Signal Frequency (Hz, log scale)') freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq]
axis.set_ylabel('Waveform Frequency (Hz)') freq_y = [f for f in freq if f >= 1000 and f in y_ticks and f <= sample_rate // 2]
axis.xaxis.grid(True, alpha=0.67)
axis.yaxis.grid(True, alpha=0.67) figure, axis = plt.subplots(1, 1)
figure.suptitle(f'{title} (sample rate: {sample_rate} Hz)') axis.specgram(waveform[0].numpy(), Fs=sample_rate)
plt.show(block=True) plt.xticks(time, freq_x)
plt.yticks(freq_y, freq_y)
axis.set_xlabel("Original Signal Frequency (Hz, log scale)")
axis.set_ylabel("Waveform Frequency (Hz)")
axis.xaxis.grid(True, alpha=0.67)
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.show(block=True)
def play_audio(waveform, sample_rate): def play_audio(waveform, sample_rate):
waveform = waveform.numpy() waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
waveform = waveform.numpy() waveform = waveform.numpy()
num_channels, num_frames = waveform.shape num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
figure, axes = plt.subplots(num_channels, 1) if num_channels == 1:
if num_channels == 1: axes = [axes]
axes = [axes] for c in range(num_channels):
for c in range(num_channels): axes[c].specgram(waveform[c], Fs=sample_rate)
axes[c].specgram(waveform[c], Fs=sample_rate) if num_channels > 1:
if num_channels > 1: axes[c].set_ylabel(f"Channel {c+1}")
axes[c].set_ylabel(f'Channel {c+1}') if xlim:
if xlim: axes[c].set_xlim(xlim)
axes[c].set_xlim(xlim) figure.suptitle(title)
figure.suptitle(title) plt.show(block=False)
plt.show(block=False)
def benchmark_resample( def benchmark_resample(
method, method,
...@@ -146,30 +159,45 @@ def benchmark_resample( ...@@ -146,30 +159,45 @@ def benchmark_resample(
resampling_method=DEFAULT_RESAMPLING_METHOD, resampling_method=DEFAULT_RESAMPLING_METHOD,
beta=None, beta=None,
librosa_type=None, librosa_type=None,
iters=5 iters=5,
): ):
if method == "functional": if method == "functional":
begin = time.time() begin = time.time()
for _ in range(iters): for _ in range(iters):
F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=lowpass_filter_width, F.resample(
rolloff=rolloff, resampling_method=resampling_method) waveform,
elapsed = time.time() - begin sample_rate,
return elapsed / iters resample_rate,
elif method == "transforms": lowpass_filter_width=lowpass_filter_width,
resampler = T.Resample(sample_rate, resample_rate, lowpass_filter_width=lowpass_filter_width, rolloff=rolloff,
rolloff=rolloff, resampling_method=resampling_method, dtype=waveform.dtype) resampling_method=resampling_method,
begin = time.time() )
for _ in range(iters): elapsed = time.time() - begin
resampler(waveform) return elapsed / iters
elapsed = time.time() - begin elif method == "transforms":
return elapsed / iters resampler = T.Resample(
elif method == "librosa": sample_rate,
waveform_np = waveform.squeeze().numpy() resample_rate,
begin = time.time() lowpass_filter_width=lowpass_filter_width,
for _ in range(iters): rolloff=rolloff,
librosa.resample(waveform_np, sample_rate, resample_rate, res_type=librosa_type) resampling_method=resampling_method,
elapsed = time.time() - begin dtype=waveform.dtype,
return elapsed / iters )
begin = time.time()
for _ in range(iters):
resampler(waveform)
elapsed = time.time() - begin
return elapsed / iters
elif method == "librosa":
waveform_np = waveform.squeeze().numpy()
begin = time.time()
for _ in range(iters):
librosa.resample(
waveform_np, sample_rate, resample_rate, res_type=librosa_type
)
elapsed = time.time() - begin
return elapsed / iters
###################################################################### ######################################################################
# To resample an audio waveform from one frequency to another, you can use # To resample an audio waveform from one frequency to another, you can use
...@@ -202,6 +230,7 @@ def benchmark_resample( ...@@ -202,6 +230,7 @@ def benchmark_resample(
# plotted waveform, and color intensity the amplitude. # plotted waveform, and color intensity the amplitude.
# #
sample_rate = 48000 sample_rate = 48000
resample_rate = 32000 resample_rate = 32000
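######################################################################
# Both ``torchaudio.functional.resample`` and the ``torchaudio.transforms.Resample``
# transform are used throughout this section. A minimal, self-contained sketch of
# the basic call in each form (the test tone and variable names are illustrative,
# and the filter/window parameters explored below are left at their defaults):

import math

import torch
import torchaudio.functional as F
import torchaudio.transforms as T

orig_freq, new_freq = 48000, 32000
t = torch.arange(orig_freq, dtype=torch.float32) / orig_freq
tone = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)  # 1 second, 440 Hz

# Functional form: a one-off resampling call.
resampled_tone = F.resample(tone, orig_freq, new_freq)

# Transform form: precomputes the kernel once and can be reused on many waveforms.
resampler = T.Resample(orig_freq, new_freq, dtype=tone.dtype)
resampled_tone_t = resampler(tone)

print(resampled_tone.shape, resampled_tone_t.shape)  # both torch.Size([1, 32000])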
...@@ -235,10 +264,14 @@ play_audio(waveform, sample_rate) ...@@ -235,10 +264,14 @@ play_audio(waveform, sample_rate)
sample_rate = 48000 sample_rate = 48000
resample_rate = 32000 resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6) resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, lowpass_filter_width=6
)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6") plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128) resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, lowpass_filter_width=128
)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128") plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")
...@@ -282,10 +315,14 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8") ...@@ -282,10 +315,14 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
sample_rate = 48000 sample_rate = 48000
resample_rate = 32000 resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation") resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation"
)
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default") plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window") resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, resampling_method="kaiser_window"
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
...@@ -301,7 +338,7 @@ plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default") ...@@ -301,7 +338,7 @@ plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
sample_rate = 48000 sample_rate = 48000
resample_rate = 32000 resample_rate = 32000
### kaiser_best # kaiser_best
resampled_waveform = F.resample( resampled_waveform = F.resample(
waveform, waveform,
sample_rate, sample_rate,
...@@ -309,18 +346,23 @@ resampled_waveform = F.resample( ...@@ -309,18 +346,23 @@ resampled_waveform = F.resample(
lowpass_filter_width=64, lowpass_filter_width=64,
rolloff=0.9475937167399596, rolloff=0.9475937167399596,
resampling_method="kaiser_window", resampling_method="kaiser_window",
beta=14.769656459379492 beta=14.769656459379492,
) )
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
librosa_resampled_waveform = torch.from_numpy( librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), sample_rate, resample_rate, res_type='kaiser_best')).unsqueeze(0) librosa.resample(
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)") waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_best"
)
).unsqueeze(0)
plot_sweep(
librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)"
)
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item() mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse) print("torchaudio and librosa kaiser best MSE:", mse)
### kaiser_fast # kaiser_fast
resampled_waveform = F.resample( resampled_waveform = F.resample(
waveform, waveform,
sample_rate, sample_rate,
...@@ -328,13 +370,20 @@ resampled_waveform = F.resample( ...@@ -328,13 +370,20 @@ resampled_waveform = F.resample(
lowpass_filter_width=16, lowpass_filter_width=16,
rolloff=0.85, rolloff=0.85,
resampling_method="kaiser_window", resampling_method="kaiser_window",
beta=8.555504641634386 beta=8.555504641634386,
)
plot_specgram(
resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)"
) )
plot_specgram(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")
librosa_resampled_waveform = torch.from_numpy( librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), sample_rate, resample_rate, res_type='kaiser_fast')).unsqueeze(0) librosa.resample(
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)") waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_fast"
)
).unsqueeze(0)
plot_sweep(
librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)"
)
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item() mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse) print("torchaudio and librosa kaiser fast MSE:", mse)
...@@ -371,71 +420,87 @@ configs = { ...@@ -371,71 +420,87 @@ configs = {
} }
for label in configs: for label in configs:
times, rows = [], [] times, rows = [], []
sample_rate = configs[label][0] sample_rate = configs[label][0]
resample_rate = configs[label][1] resample_rate = configs[label][1]
waveform = get_sine_sweep(sample_rate) waveform = get_sine_sweep(sample_rate)
# sinc 64 zero-crossings # sinc 64 zero-crossings
f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64) f_time = benchmark_resample(
t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64) "functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64
times.append([None, 1000 * f_time, 1000 * t_time]) )
rows.append(f"sinc (width 64)") t_time = benchmark_resample(
"transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64
# sinc 6 zero-crossings )
f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16) times.append([None, 1000 * f_time, 1000 * t_time])
t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16) rows.append("sinc (width 64)")
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append(f"sinc (width 16)") # sinc 6 zero-crossings
f_time = benchmark_resample(
# kaiser best "functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16
lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best") )
f_time = benchmark_resample( t_time = benchmark_resample(
"functional", "transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16
waveform, )
sample_rate, times.append([None, 1000 * f_time, 1000 * t_time])
resample_rate, rows.append("sinc (width 16)")
lowpass_filter_width=64,
rolloff=0.9475937167399596, # kaiser best
resampling_method="kaiser_window", lib_time = benchmark_resample(
beta=14.769656459379492) "librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best"
t_time = benchmark_resample( )
"transforms", f_time = benchmark_resample(
waveform, "functional",
sample_rate, waveform,
resample_rate, sample_rate,
lowpass_filter_width=64, resample_rate,
rolloff=0.9475937167399596, lowpass_filter_width=64,
resampling_method="kaiser_window", rolloff=0.9475937167399596,
beta=14.769656459379492) resampling_method="kaiser_window",
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time]) beta=14.769656459379492,
rows.append(f"kaiser_best") )
t_time = benchmark_resample(
# kaiser fast "transforms",
lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast") waveform,
f_time = benchmark_resample( sample_rate,
"functional", resample_rate,
waveform, lowpass_filter_width=64,
sample_rate, rolloff=0.9475937167399596,
resample_rate, resampling_method="kaiser_window",
lowpass_filter_width=16, beta=14.769656459379492,
rolloff=0.85, )
resampling_method="kaiser_window", times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
beta=8.555504641634386) rows.append("kaiser_best")
t_time = benchmark_resample(
"transforms", # kaiser fast
waveform, lib_time = benchmark_resample(
sample_rate, "librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast"
resample_rate, )
lowpass_filter_width=16, f_time = benchmark_resample(
rolloff=0.85, "functional",
resampling_method="kaiser_window", waveform,
beta=8.555504641634386) sample_rate,
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time]) resample_rate,
rows.append(f"kaiser_fast") lowpass_filter_width=16,
rolloff=0.85,
df = pd.DataFrame(times, resampling_method="kaiser_window",
columns=["librosa", "functional", "transforms"], beta=8.555504641634386,
index=rows) )
df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"],df.columns]) t_time = benchmark_resample(
display(df.round(2)) "transforms",
waveform,
sample_rate,
resample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
rows.append("kaiser_fast")
df = pd.DataFrame(
times, columns=["librosa", "functional", "transforms"], index=rows
)
df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns])
display(df.round(2))
...@@ -15,25 +15,25 @@ Recognition <https://arxiv.org/abs/2007.09127>`__. ...@@ -15,25 +15,25 @@ Recognition <https://arxiv.org/abs/2007.09127>`__.
###################################################################### ######################################################################
# Overview # Overview
# -------- # --------
# #
# The process of alignment looks like the following. # The process of alignment looks like the following.
# #
# 1. Estimate the frame-wise label probability from the audio waveform # 1. Estimate the frame-wise label probability from the audio waveform
# 2. Generate the trellis matrix which represents the probability of # 2. Generate the trellis matrix which represents the probability of
# labels being aligned at each time step. # labels being aligned at each time step.
# 3. Find the most likely path from the trellis matrix. # 3. Find the most likely path from the trellis matrix.
# #
# In this example, we use ``torchaudio``\ ’s ``Wav2Vec2`` model for # In this example, we use ``torchaudio``\ ’s ``Wav2Vec2`` model for
# acoustic feature extraction. # acoustic feature extraction.
# #
###################################################################### ######################################################################
# Preparation # Preparation
# ----------- # -----------
# #
# First we import the necessary packages, and fetch data that we work on. # First we import the necessary packages, and fetch data that we work on.
# #
# %matplotlib inline # %matplotlib inline
...@@ -47,48 +47,48 @@ import matplotlib ...@@ -47,48 +47,48 @@ import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import IPython import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8] matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0) torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__) print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
print(device) print(device)
SPEECH_URL = 'https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav' SPEECH_URL = "https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_FILE = '_assets/speech.wav' SPEECH_FILE = "_assets/speech.wav"
if not os.path.exists(SPEECH_FILE): if not os.path.exists(SPEECH_FILE):
os.makedirs('_assets', exist_ok=True) os.makedirs("_assets", exist_ok=True)
with open(SPEECH_FILE, 'wb') as file: with open(SPEECH_FILE, "wb") as file:
file.write(requests.get(SPEECH_URL).content) file.write(requests.get(SPEECH_URL).content)
###################################################################### ######################################################################
# Generate frame-wise label probability # Generate frame-wise label probability
# ------------------------------------- # -------------------------------------
# #
# The first step is to generate the label class probability of each audio # The first step is to generate the label class probability of each audio
# frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use # frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use
# :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`. # :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
# #
# ``torchaudio`` provides easy access to pretrained models with associated # ``torchaudio`` provides easy access to pretrained models with associated
# labels. # labels.
# #
# .. note:: # .. note::
# #
# In the subsequent sections, we will compute the probability in # In the subsequent sections, we will compute the probability in
# log-domain to avoid numerical instability. For this purpose, we # log-domain to avoid numerical instability. For this purpose, we
# normalize the ``emission`` with :py:func:`torch.log_softmax`. # normalize the ``emission`` with :py:func:`torch.log_softmax`.
# #
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device) model = bundle.get_model().to(device)
labels = bundle.get_labels() labels = bundle.get_labels()
with torch.inference_mode(): with torch.inference_mode():
waveform, _ = torchaudio.load(SPEECH_FILE) waveform, _ = torchaudio.load(SPEECH_FILE)
emissions, _ = model(waveform.to(device)) emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1) emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach() emission = emissions[0].cpu().detach()
...@@ -107,16 +107,16 @@ plt.show() ...@@ -107,16 +107,16 @@ plt.show()
###################################################################### ######################################################################
# Generate alignment probability (trellis) # Generate alignment probability (trellis)
# ---------------------------------------- # ----------------------------------------
# #
# From the emission matrix, next we generate the trellis which represents # From the emission matrix, next we generate the trellis which represents
# the probability of transcript labels occurring at each time frame. # the probability of transcript labels occurring at each time frame.
# #
# The trellis is a 2D matrix with a time axis and a label axis. The label axis # The trellis is a 2D matrix with a time axis and a label axis. The label axis
# represents the transcript that we are aligning. In the following, we use # represents the transcript that we are aligning. In the following, we use
# :math:`t` to denote the index in time axis and :math:`j` to denote the # :math:`t` to denote the index in time axis and :math:`j` to denote the
# index in label axis. :math:`c_j` represents the label at label index # index in label axis. :math:`c_j` represents the label at label index
# :math:`j`. # :math:`j`.
# #
# To generate the probability of time step :math:`t+1`, we look at the # To generate the probability of time step :math:`t+1`, we look at the
# trellis from time step :math:`t` and emission at time step :math:`t+1`. # trellis from time step :math:`t` and emission at time step :math:`t+1`.
# There are two paths to reach time step :math:`t+1` with label # There are two paths to reach time step :math:`t+1` with label
...@@ -125,53 +125,55 @@ plt.show() ...@@ -125,53 +125,55 @@ plt.show()
# :math:`t` to :math:`t+1`. The other case is where the label was # :math:`t` to :math:`t+1`. The other case is where the label was
# :math:`c_j` at :math:`t` and it transitioned to the next label # :math:`c_j` at :math:`t` and it transitioned to the next label
# :math:`c_{j+1}` at :math:`t+1`. # :math:`c_{j+1}` at :math:`t+1`.
# #
# The following diagram illustrates this transition. # The following diagram illustrates this transition.
# #
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png # .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png
# #
# Since we are looking for the most likely transitions, we take the more # Since we are looking for the most likely transitions, we take the more
# likely path for the value of :math:`k_{(t+1, j+1)}`, that is # likely path for the value of :math:`k_{(t+1, j+1)}`, that is
# #
# :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )` # :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )`
# #
# where :math:`k` represents the trellis matrix, and :math:`p(t, c_j)` # where :math:`k` represents the trellis matrix, and :math:`p(t, c_j)`
# represents the probability of label :math:`c_j` at time step :math:`t`. # represents the probability of label :math:`c_j` at time step :math:`t`.
# :math:`repeat` represents the blank token from the CTC formulation. (For the # :math:`repeat` represents the blank token from the CTC formulation. (For the
# details of the CTC algorithm, please refer to the *Sequence Modeling with CTC* # details of the CTC algorithm, please refer to the *Sequence Modeling with CTC*
# [`distill.pub <https://distill.pub/2017/ctc/>`__]) # [`distill.pub <https://distill.pub/2017/ctc/>`__])
# #
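# Since the emission is kept in the log domain (see the note above), the
# products in this formula become sums in the code. A tiny toy illustration
# of a single update step, with made-up log-probabilities:

k_t_j, k_t_j1 = -1.0, -2.5       # trellis values k_(t, j) and k_(t, j+1)
p_change, p_repeat = -0.2, -0.1  # log p(t+1, c_{j+1}) and log p(t+1, repeat)

change = k_t_j + p_change        # came from c_j and switched to c_{j+1}
stay = k_t_j1 + p_repeat         # was already at c_{j+1} and emitted blank
print(max(change, stay))         # about -1.2: the "change" transition wins

# ``get_trellis`` below performs this same comparison for every (t, j) pair
# at once, using the element-wise ``torch.maximum``.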
transcript = 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT' transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
dictionary = {c: i for i, c in enumerate(labels)} dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in transcript] tokens = [dictionary[c] for c in transcript]
print(list(zip(transcript, tokens))) print(list(zip(transcript, tokens)))
def get_trellis(emission, tokens, blank_id=0): def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0) num_frame = emission.size(0)
num_tokens = len(tokens) num_tokens = len(tokens)
# The trellis has extra dimensions for both the time axis and tokens. # The trellis has extra dimensions for both the time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence) # The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code. # The extra dim for time axis is for simplification of the code.
trellis = torch.full((num_frame+1, num_tokens+1), -float('inf')) trellis = torch.full((num_frame + 1, num_tokens + 1), -float("inf"))
trellis[:, 0] = 0 trellis[:, 0] = 0
for t in range(num_frame): for t in range(num_frame):
trellis[t+1, 1:] = torch.maximum( trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token # Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id], trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token # Score for changing to the next token
trellis[t, :-1] + emission[t, tokens], trellis[t, :-1] + emission[t, tokens],
) )
return trellis return trellis
trellis = get_trellis(emission, tokens) trellis = get_trellis(emission, tokens)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ ################################################################################
plt.imshow(trellis[1:, 1:].T, origin='lower') plt.imshow(trellis[1:, 1:].T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5)) plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.colorbar() plt.colorbar()
plt.show() plt.show()
...@@ -179,84 +181,88 @@ plt.show() ...@@ -179,84 +181,88 @@ plt.show()
###################################################################### ######################################################################
# In the above visualization, we can see that there is a trace of high # In the above visualization, we can see that there is a trace of high
# probability crossing the matrix diagonally. # probability crossing the matrix diagonally.
# #
###################################################################### ######################################################################
# Find the most likely path (backtracking) # Find the most likely path (backtracking)
# ---------------------------------------- # ----------------------------------------
# #
# Once the trellis is generated, we will traverse it following the # Once the trellis is generated, we will traverse it following the
# elements with high probability. # elements with high probability.
# #
# We will start from the last label index, at the time step with the highest # We will start from the last label index, at the time step with the highest
# probability; then we traverse back in time, picking stay # probability; then we traverse back in time, picking stay
# (:math:`c_j \rightarrow c_j`) or transition # (:math:`c_j \rightarrow c_j`) or transition
# (:math:`c_j \rightarrow c_{j+1}`), based on the post-transition # (:math:`c_j \rightarrow c_{j+1}`), based on the post-transition
# probability :math:`k_{t, j} p(t+1, c_{j+1})` or # probability :math:`k_{t, j} p(t+1, c_{j+1})` or
# :math:`k_{t, j+1} p(t+1, repeat)`. # :math:`k_{t, j+1} p(t+1, repeat)`.
# #
# The traversal is done once the label index reaches the beginning. # The traversal is done once the label index reaches the beginning.
# #
# The trellis matrix is used for path-finding, but for the final # The trellis matrix is used for path-finding, but for the final
# probability of each segment, we take the frame-wise probability from # probability of each segment, we take the frame-wise probability from
# the emission matrix. # the emission matrix.
# #
@dataclass @dataclass
class Point: class Point:
token_index: int token_index: int
time_index: int time_index: int
score: float score: float
def backtrack(trellis, emission, tokens, blank_id=0): def backtrack(trellis, emission, tokens, blank_id=0):
# Note: # Note:
# j and t are indices for trellis, which has extra dimensions # j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning. # for time and tokens at the beginning.
# When refering to time frame index `T` in trellis, # When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`. # the corresponding index in emission is `T-1`.
# Similarly, when refering to token index `J` in trellis, # Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`. # the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1 j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item() t_start = torch.argmax(trellis[:, j]).item()
path = [] path = []
for t in range(t_start, 0, -1): for t in range(t_start, 0, -1):
# 1. Figure out if the current position was stay or change # 1. Figure out if the current position was stay or change
# Note (again): # Note (again):
# `emission[T-1]` is the emission at time frame `T` of the trellis dimension. # `emission[T-1]` is the emission at time frame `T` of the trellis dimension.
# Score for the token staying the same from time frame T-1 to T. # Score for the token staying the same from time frame T-1 to T.
stayed = trellis[t-1, j] + emission[t-1, blank_id] stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for the token changing from J-1 at T-1 to J at T. # Score for the token changing from J-1 at T-1 to J at T.
changed = trellis[t-1, j-1] + emission[t-1, tokens[j-1]] changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
# 2. Store the path with frame-wise probability. # 2. Store the path with frame-wise probability.
prob = emission[t-1, tokens[j-1] if changed > stayed else 0].exp().item() prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate. # Return token index and time index in non-trellis coordinate.
path.append(Point(j-1, t-1, prob)) path.append(Point(j - 1, t - 1, prob))
# 3. Update the token # 3. Update the token
if changed > stayed: if changed > stayed:
j -= 1 j -= 1
if j == 0: if j == 0:
break break
else: else:
raise ValueError('Failed to align') raise ValueError("Failed to align")
return path[::-1] return path[::-1]
path = backtrack(trellis, emission, tokens) path = backtrack(trellis, emission, tokens)
print(path) print(path)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ ################################################################################
def plot_trellis_with_path(trellis, path): def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
for i, p in enumerate(path): for _, p in enumerate(path):
trellis_with_path[p.time_index, p.token_index] = float('nan') trellis_with_path[p.time_index, p.token_index] = float("nan")
plt.imshow(trellis_with_path[1:, 1:].T, origin='lower') plt.imshow(trellis_with_path[1:, 1:].T, origin="lower")
plot_trellis_with_path(trellis, path) plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking") plt.title("The path found by backtracking")
...@@ -265,82 +271,94 @@ plt.show() ...@@ -265,82 +271,94 @@ plt.show()
###################################################################### ######################################################################
# Looking good. Now this path contains repetitions for the same labels, so # Looking good. Now this path contains repetitions for the same labels, so
# let’s merge them to make it close to the original transcript. # let’s merge them to make it close to the original transcript.
# #
# When merging the multiple path points, we simply take the average # When merging the multiple path points, we simply take the average
# probability for the merged segments. # probability for the merged segments.
# #
# Merge the labels # Merge the labels
@dataclass @dataclass
class Segment: class Segment:
label: str label: str
start: int start: int
end: int end: int
score: float score: float
def __repr__(self): def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property
def length(self):
return self.end - self.start
@property
def length(self):
return self.end - self.start
def merge_repeats(path): def merge_repeats(path):
i1, i2 = 0, 0 i1, i2 = 0, 0
segments = [] segments = []
while i1 < len(path): while i1 < len(path):
while i2 < len(path) and path[i1].token_index == path[i2].token_index: while i2 < len(path) and path[i1].token_index == path[i2].token_index:
i2 += 1 i2 += 1
score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1) score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
segments.append(Segment(transcript[path[i1].token_index], path[i1].time_index, path[i2-1].time_index + 1, score)) segments.append(
i1 = i2 Segment(
return segments transcript[path[i1].token_index],
path[i1].time_index,
path[i2 - 1].time_index + 1,
score,
)
)
i1 = i2
return segments
segments = merge_repeats(path) segments = merge_repeats(path)
for seg in segments: for seg in segments:
print(seg) print(seg)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ ################################################################################
def plot_trellis_with_segments(trellis, segments, transcript): def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != '|': if seg.label != "|":
trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan') trellis_with_path[seg.start + 1: seg.end + 1, i + 1] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.set_title("Path, label and probability for each label") ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin='lower') ax1.imshow(trellis_with_path.T, origin="lower")
ax1.set_xticks([]) ax1.set_xticks([])
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != '|': if seg.label != "|":
ax1.annotate(seg.label, (seg.start + .7, i + 0.3), weight='bold') ax1.annotate(seg.label, (seg.start + 0.7, i + 0.3), weight="bold")
ax1.annotate(f'{seg.score:.2f}', (seg.start - .3, i + 4.3)) ax1.annotate(f"{seg.score:.2f}", (seg.start - 0.3, i + 4.3))
ax2.set_title("Label probability with and without repetation") ax2.set_title("Label probability with and without repetation")
xs, hs, ws = [], [], [] xs, hs, ws = [], [], []
for seg in segments: for seg in segments:
if seg.label != '|': if seg.label != "|":
xs.append((seg.end + seg.start) / 2 + .4) xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score) hs.append(seg.score)
ws.append(seg.end - seg.start) ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + .8, -0.07), weight='bold') ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax2.bar(xs, hs, width=ws, color='gray', alpha=0.5, edgecolor='black') ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], [] xs, hs = [], []
for p in path: for p in path:
label = transcript[p.token_index] label = transcript[p.token_index]
if label != '|': if label != "|":
xs.append(p.time_index + 1) xs.append(p.time_index + 1)
hs.append(p.score) hs.append(p.score)
ax2.bar(xs, hs, width=0.5, alpha=0.5) ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color='black') ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim()) ax2.set_xlim(ax1.get_xlim())
ax2.set_ylim(-0.1, 1.1) ax2.set_ylim(-0.1, 1.1)
plot_trellis_with_segments(trellis, segments, transcript) plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout() plt.tight_layout()
...@@ -351,93 +369,109 @@ plt.show() ...@@ -351,93 +369,109 @@ plt.show()
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'`` # Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of # as the word boundary, so we merge the segments before each occurrence of
# ``'|'``. # ``'|'``.
# #
# Then, finally, we segment the original audio into segmented audio and # Then, finally, we segment the original audio into segmented audio and
# listen to them to see if the segmentation is correct. # listen to them to see if the segmentation is correct.
# #
# Merge words # Merge words
def merge_words(segments, separator='|'): def merge_words(segments, separator="|"):
words = [] words = []
i1, i2 = 0, 0 i1, i2 = 0, 0
while i1 < len(segments): while i1 < len(segments):
if i2 >= len(segments) or segments[i2].label == separator: if i2 >= len(segments) or segments[i2].label == separator:
if i1 != i2: if i1 != i2:
segs = segments[i1:i2] segs = segments[i1:i2]
word = ''.join([seg.label for seg in segs]) word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) score = sum(seg.score * seg.length for seg in segs) / sum(
words.append(Segment(word, segments[i1].start, segments[i2-1].end, score)) seg.length for seg in segs
i1 = i2 + 1 )
i2 = i1 words.append(
else: Segment(word, segments[i1].start, segments[i2 - 1].end, score)
i2 += 1 )
return words i1 = i2 + 1
i2 = i1
else:
i2 += 1
return words
word_segments = merge_words(segments) word_segments = merge_words(segments)
for word in word_segments: for word in word_segments:
print(word) print(word)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ ################################################################################
def plot_alignments(trellis, segments, word_segments, waveform): def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != '|': if seg.label != "|":
trellis_with_path[seg.start+1:seg.end+1, i+1] = float('nan') trellis_with_path[seg.start + 1: seg.end + 1, i + 1] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.imshow(trellis_with_path[1:, 1:].T, origin='lower') ax1.imshow(trellis_with_path[1:, 1:].T, origin="lower")
ax1.set_xticks([]) ax1.set_xticks([])
ax1.set_yticks([]) ax1.set_yticks([])
for word in word_segments: for word in word_segments:
ax1.axvline(word.start - 0.5) ax1.axvline(word.start - 0.5)
ax1.axvline(word.end - 0.5) ax1.axvline(word.end - 0.5)
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != '|': if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i + 0.3)) ax1.annotate(seg.label, (seg.start, i + 0.3))
ax1.annotate(f'{seg.score:.2f}', (seg.start , i + 4), fontsize=8) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 4), fontsize=8)
# The original waveform # The original waveform
ratio = waveform.size(0) / (trellis.size(0) - 1) ratio = waveform.size(0) / (trellis.size(0) - 1)
ax2.plot(waveform) ax2.plot(waveform)
for word in word_segments: for word in word_segments:
x0 = ratio * word.start x0 = ratio * word.start
x1 = ratio * word.end x1 = ratio * word.end
ax2.axvspan(x0, x1, alpha=0.1, color='red') ax2.axvspan(x0, x1, alpha=0.1, color="red")
ax2.annotate(f'{word.score:.2f}', (x0, 0.8)) ax2.annotate(f"{word.score:.2f}", (x0, 0.8))
for seg in segments: for seg in segments:
if seg.label != '|': if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9)) ax2.annotate(seg.label, (seg.start * ratio, 0.9))
xticks = ax2.get_xticks() xticks = ax2.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate) plt.xticks(xticks, xticks / bundle.sample_rate)
ax2.set_xlabel('time [second]') ax2.set_xlabel("time [second]")
ax2.set_yticks([]) ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0) ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1)) ax2.set_xlim(0, waveform.size(-1))
plot_alignments(trellis, segments, word_segments, waveform[0],)
plot_alignments(
trellis,
segments,
word_segments,
waveform[0],
)
plt.show() plt.show()
# A trick to embed the resulting audio into the generated file. # A trick to embed the resulting audio into the generated file.
# `IPython.display.Audio` has to be the last call in a cell, # `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell. # and there should be only one call per cell.
def display_segment(i): def display_segment(i):
ratio = waveform.size(1) / (trellis.size(0) - 1) ratio = waveform.size(1) / (trellis.size(0) - 1)
word = word_segments[i] word = word_segments[i]
x0 = int(ratio * word.start) x0 = int(ratio * word.start)
x1 = int(ratio * word.end) x1 = int(ratio * word.end)
filename = f"_assets/{i}_{word.label}.wav" filename = f"_assets/{i}_{word.label}.wav"
torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate) torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec") print(
return IPython.display.Audio(filename) f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec"
)
return IPython.display.Audio(filename)
###################################################################### ######################################################################
# #
# Generate the audio for each segment # Generate the audio for each segment
print(transcript) print(transcript)
...@@ -445,54 +479,54 @@ IPython.display.Audio(SPEECH_FILE) ...@@ -445,54 +479,54 @@ IPython.display.Audio(SPEECH_FILE)
###################################################################### ######################################################################
# #
display_segment(0) display_segment(0)
###################################################################### ######################################################################
# #
display_segment(1) display_segment(1)
###################################################################### ######################################################################
# #
display_segment(2) display_segment(2)
###################################################################### ######################################################################
# #
display_segment(3) display_segment(3)
###################################################################### ######################################################################
# #
display_segment(4) display_segment(4)
###################################################################### ######################################################################
# #
display_segment(5) display_segment(5)
###################################################################### ######################################################################
# #
display_segment(6) display_segment(6)
###################################################################### ######################################################################
# #
display_segment(7) display_segment(7)
###################################################################### ######################################################################
# #
display_segment(8) display_segment(8)
###################################################################### ######################################################################
# Conclusion # Conclusion
# ---------- # ----------
# #
# In this tutorial, we looked at how to use torchaudio’s Wav2Vec2 model to # In this tutorial, we looked at how to use torchaudio’s Wav2Vec2 model to
# perform CTC segmentation for forced alignment. # perform CTC segmentation for forced alignment.
# #
...@@ -9,11 +9,11 @@ MVDR with torchaudio ...@@ -9,11 +9,11 @@ MVDR with torchaudio
###################################################################### ######################################################################
# Overview # Overview
# -------- # --------
# #
# This is a tutorial on how to apply MVDR beamforming by using `torchaudio <https://github.com/pytorch/audio>`__. # This is a tutorial on how to apply MVDR beamforming by using `torchaudio <https://github.com/pytorch/audio>`__.
# #
# Steps # Steps
# #
# - The Ideal Ratio Mask (IRM) is generated by dividing the clean (or noise) # - The Ideal Ratio Mask (IRM) is generated by dividing the clean (or noise)
# magnitude by the sum of the clean and noise magnitudes. # magnitude by the sum of the clean and noise magnitudes.
# - We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``) # - We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``)
...@@ -26,22 +26,22 @@ MVDR with torchaudio ...@@ -26,22 +26,22 @@ MVDR with torchaudio
###################################################################### ######################################################################
# Preparation # Preparation
# ----------- # -----------
# #
# First, we import the necessary packages and retrieve the data. # First, we import the necessary packages and retrieve the data.
# #
# The multi-channel audio example is selected from # The multi-channel audio example is selected from
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__ # `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
# dataset. # dataset.
# #
# The original filename is # The original filename is
# #
# ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav`` # ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
# #
# which was generated with: # which was generated with:
# #
# - ``SSB07200001.wav`` from `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License v.2.0) # - ``SSB07200001.wav`` from `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License v.2.0)
# - ``noise-sound-bible-0038.wav`` from `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0 International — CC BY 4.0) # - ``noise-sound-bible-0038.wav`` from `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0 International — CC BY 4.0) # noqa: E501
# #
import os import os
import requests import requests
...@@ -50,48 +50,48 @@ import torchaudio ...@@ -50,48 +50,48 @@ import torchaudio
import IPython.display as ipd import IPython.display as ipd
torch.random.manual_seed(0) torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__) print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
print(device) print(device)
filenames = [ filenames = [
'mix.wav', "mix.wav",
'reverb_clean.wav', "reverb_clean.wav",
'clean.wav', "clean.wav",
] ]
base_url = 'https://download.pytorch.org/torchaudio/tutorial-assets/mvdr' base_url = "https://download.pytorch.org/torchaudio/tutorial-assets/mvdr"
for filename in filenames: for filename in filenames:
os.makedirs('_assets', exist_ok=True) os.makedirs("_assets", exist_ok=True)
if not os.path.exists(filename): if not os.path.exists(filename):
with open(f'_assets/{filename}', 'wb') as file: with open(f"_assets/{filename}", "wb") as file:
file.write(requests.get(f'{base_url}/{filename}').content) file.write(requests.get(f"{base_url}/{filename}").content)
###################################################################### ######################################################################
# Generate the Ideal Ratio Mask (IRM) # Generate the Ideal Ratio Mask (IRM)
# ----------------------------------- # -----------------------------------
# #
###################################################################### ######################################################################
# Loading audio data # Loading audio data
# ~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~
# #
mix, sr = torchaudio.load('_assets/mix.wav') mix, sr = torchaudio.load("_assets/mix.wav")
reverb_clean, sr2 = torchaudio.load('_assets/reverb_clean.wav') reverb_clean, sr2 = torchaudio.load("_assets/reverb_clean.wav")
clean, sr3 = torchaudio.load('_assets/clean.wav') clean, sr3 = torchaudio.load("_assets/clean.wav")
assert sr == sr2 assert sr == sr2
noise = mix - reverb_clean noise = mix - reverb_clean
###################################################################### ######################################################################
# #
# .. note:: # .. note::
# The MVDR Module requires ``torch.cdouble`` dtype for noisy STFT. # The MVDR Module requires ``torch.cdouble`` dtype for noisy STFT.
# We need to convert the dtype of the waveforms to ``torch.double`` # We need to convert the dtype of the waveforms to ``torch.double``
# #
mix = mix.to(torch.double) mix = mix.to(torch.double)
noise = noise.to(torch.double) noise = noise.to(torch.double)
...@@ -101,7 +101,7 @@ reverb_clean = reverb_clean.to(torch.double) ...@@ -101,7 +101,7 @@ reverb_clean = reverb_clean.to(torch.double)
###################################################################### ######################################################################
# Compute STFT # Compute STFT
# ~~~~~~~~~~~~ # ~~~~~~~~~~~~
# #
stft = torchaudio.transforms.Spectrogram( stft = torchaudio.transforms.Spectrogram(
n_fft=1024, n_fft=1024,
...@@ -118,15 +118,14 @@ spec_noise = stft(noise) ...@@ -118,15 +118,14 @@ spec_noise = stft(noise)
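######################################################################
# For orientation, a typical forward/inverse ``Spectrogram`` pairing is sketched
# below; the hop length and the use of ``InverseSpectrogram`` for the ``istft``
# used later are illustrative assumptions, not necessarily the tutorial's exact
# settings.

import torchaudio

n_fft_sketch, hop_length_sketch = 1024, 256  # n_fft matches the value shown above
stft_sketch = torchaudio.transforms.Spectrogram(
    n_fft=n_fft_sketch,
    hop_length=hop_length_sketch,
    power=None,  # keep the complex-valued spectrogram, as MVDR requires
)
istft_sketch = torchaudio.transforms.InverseSpectrogram(
    n_fft=n_fft_sketch,
    hop_length=hop_length_sketch,
)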
###################################################################### ######################################################################
# Generate the Ideal Ratio Mask (IRM) # Generate the Ideal Ratio Mask (IRM)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# .. note:: # .. note::
# We found using the mask directly performs better than using the # We found using the mask directly performs better than using the
# square root of it. This is slightly different from the definition of IRM. # square root of it. This is slightly different from the definition of IRM.
# #
def get_irms(spec_clean, spec_noise, spec_mix): def get_irms(spec_clean, spec_noise):
mag_mix = spec_mix.abs() ** 2
mag_clean = spec_clean.abs() ** 2 mag_clean = spec_clean.abs() ** 2
mag_noise = spec_noise.abs() ** 2 mag_noise = spec_noise.abs() ** 2
irm_speech = mag_clean / (mag_clean + mag_noise) irm_speech = mag_clean / (mag_clean + mag_noise)
...@@ -134,25 +133,26 @@ def get_irms(spec_clean, spec_noise, spec_mix): ...@@ -134,25 +133,26 @@ def get_irms(spec_clean, spec_noise, spec_mix):
return irm_speech, irm_noise return irm_speech, irm_noise
###################################################################### ######################################################################
# .. note:: # .. note::
# We use reverberant clean speech as the target here; # We use reverberant clean speech as the target here;
# you can also set it to dry clean speech. # you can also set it to dry clean speech.
irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise, spec_mix) irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise)
###################################################################### ######################################################################
# Apply MVDR # Apply MVDR
# ---------- # ----------
# #
###################################################################### ######################################################################
# Apply MVDR beamforming by using multi-channel masks # Apply MVDR beamforming by using multi-channel masks
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
results_multi = {} results_multi = {}
for solution in ['ref_channel', 'stv_evd', 'stv_power']: for solution in ["ref_channel", "stv_evd", "stv_power"]:
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True) mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True)
stft_est = mvdr(spec_mix, irm_speech, irm_noise) stft_est = mvdr(spec_mix, irm_speech, irm_noise)
est = istft(stft_est, length=mix.shape[-1]) est = istft(stft_est, length=mix.shape[-1])
...@@ -161,13 +161,15 @@ for solution in ['ref_channel', 'stv_evd', 'stv_power']: ...@@ -161,13 +161,15 @@ for solution in ['ref_channel', 'stv_evd', 'stv_power']:
###################################################################### ######################################################################
# Apply MVDR beamforming by using single-channel masks # Apply MVDR beamforming by using single-channel masks
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# We use the 1st channel as an example. # We use the 1st channel as an example.
# The channel selection may depend on the design of the microphone array. # The channel selection may depend on the design of the microphone array.
results_single = {} results_single = {}
for solution in ['ref_channel', 'stv_evd', 'stv_power']: for solution in ["ref_channel", "stv_evd", "stv_power"]:
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False) mvdr = torchaudio.transforms.MVDR(
ref_channel=0, solution=solution, multi_mask=False
)
stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0]) stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])
est = istft(stft_est, length=mix.shape[-1]) est = istft(stft_est, length=mix.shape[-1])
results_single[solution] = est results_single[solution] = est
...@@ -175,7 +177,8 @@ for solution in ['ref_channel', 'stv_evd', 'stv_power']: ...@@ -175,7 +177,8 @@ for solution in ['ref_channel', 'stv_evd', 'stv_power']:
###################################################################### ######################################################################
# Compute Si-SDR scores # Compute Si-SDR scores
# ~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~
# #
def si_sdr(estimate, reference, epsilon=1e-8): def si_sdr(estimate, reference, epsilon=1e-8):
estimate = estimate - estimate.mean() estimate = estimate - estimate.mean()
...@@ -196,96 +199,101 @@ def si_sdr(estimate, reference, epsilon=1e-8): ...@@ -196,96 +199,101 @@ def si_sdr(estimate, reference, epsilon=1e-8):
sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow) sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
return sisdr.item() return sisdr.item()
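######################################################################
# Not part of the original tutorial: a quick sanity check of ``si_sdr``
# on synthetic 2-D ``[channel, time]`` tensors. Since the metric is
# scale-invariant, rescaling the estimate should not change the score.
#
torch.manual_seed(0)
reference = torch.rand(1, 16000, dtype=torch.double) - 0.5
estimate = reference + 0.01 * torch.randn(1, 16000, dtype=torch.double)
print(si_sdr(estimate, reference))        # high score: estimate is close to reference
print(si_sdr(0.5 * estimate, reference))  # (nearly) identical score: scale-invariant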
###################################################################### ######################################################################
# Results # Results
# ------- # -------
# #
###################################################################### ######################################################################
# Single-channel mask results # Single-channel mask results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
for solution in results_single: for solution in results_single:
print(solution+": ", si_sdr(results_single[solution][None,...], reverb_clean[0:1])) print(
solution + ": ", si_sdr(results_single[solution][None, ...], reverb_clean[0:1])
)
###################################################################### ######################################################################
# Multi-channel mask results # Multi-channel mask results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
for solution in results_multi: for solution in results_multi:
print(solution+": ", si_sdr(results_multi[solution][None,...], reverb_clean[0:1])) print(
solution + ": ", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1])
)
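######################################################################
# Not part of the original tutorial: as a baseline, you can also compute
# the Si-SDR of the unprocessed first mixture channel against the
# reverberant clean target, to see how much each solution improves it.
#
print("mixture: ", si_sdr(mix[0:1], reverb_clean[0:1]))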
###################################################################### ######################################################################
# Original audio # Original audio
# -------------- # --------------
# #
###################################################################### ######################################################################
# Mixture speech # Mixture speech
# ~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~
# #
ipd.Audio(mix[0], rate=16000) ipd.Audio(mix[0], rate=16000)
###################################################################### ######################################################################
# Noise # Noise
# ~~~~~ # ~~~~~
# #
ipd.Audio(noise[0], rate=16000) ipd.Audio(noise[0], rate=16000)
###################################################################### ######################################################################
# Clean speech # Clean speech
# ~~~~~~~~~~~~ # ~~~~~~~~~~~~
# #
ipd.Audio(clean[0], rate=16000) ipd.Audio(clean[0], rate=16000)
###################################################################### ######################################################################
# Enhanced audio # Enhanced audio
# -------------- # --------------
# #
###################################################################### ######################################################################
# Multi-channel mask, ref_channel solution # Multi-channel mask, ref_channel solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_multi['ref_channel'], rate=16000) ipd.Audio(results_multi["ref_channel"], rate=16000)
###################################################################### ######################################################################
# Multi-channel mask, stv_evd solution # Multi-channel mask, stv_evd solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_multi['stv_evd'], rate=16000) ipd.Audio(results_multi["stv_evd"], rate=16000)
###################################################################### ######################################################################
# Multi-channel mask, stv_power solution # Multi-channel mask, stv_power solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_multi['stv_power'], rate=16000) ipd.Audio(results_multi["stv_power"], rate=16000)
###################################################################### ######################################################################
# Single-channel mask, ref_channel solution # Single-channel mask, ref_channel solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_single['ref_channel'], rate=16000) ipd.Audio(results_single["ref_channel"], rate=16000)
###################################################################### ######################################################################
# Single-channel mask, stv_evd solution # Single-channel mask, stv_evd solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_single['stv_evd'], rate=16000) ipd.Audio(results_single["stv_evd"], rate=16000)
###################################################################### ######################################################################
# Single-channel mask, stv_power solution # Single-channel mask, stv_power solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_single['stv_power'], rate=16000) ipd.Audio(results_single["stv_power"], rate=16000)
...@@ -14,28 +14,28 @@ pre-trained models from wav2vec 2.0 ...@@ -14,28 +14,28 @@ pre-trained models from wav2vec 2.0
###################################################################### ######################################################################
# Overview # Overview
# -------- # --------
# #
# The process of speech recognition looks like the following. # The process of speech recognition looks like the following.
# #
# 1. Extract the acoustic features from audio waveform # 1. Extract the acoustic features from audio waveform
# #
# 2. Estimate the class of the acoustic features frame-by-frame # 2. Estimate the class of the acoustic features frame-by-frame
# #
# 3. Generate hypothesis from the sequence of the class probabilities # 3. Generate hypothesis from the sequence of the class probabilities
# #
# Torchaudio provides easy access to the pre-trained weights and # Torchaudio provides easy access to the pre-trained weights and
# associated information, such as the expected sample rate and class # associated information, such as the expected sample rate and class
# labels. They are bundled together and available under # labels. They are bundled together and available under
# ``torchaudio.pipelines`` module. # ``torchaudio.pipelines`` module.
# #
###################################################################### ######################################################################
# Preparation # Preparation
# ----------- # -----------
# #
# First, we import the necessary packages and fetch the data that we work on. # First, we import the necessary packages and fetch the data that we work on.
# #
# %matplotlib inline # %matplotlib inline
...@@ -48,52 +48,52 @@ import matplotlib ...@@ -48,52 +48,52 @@ import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import IPython import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8] matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0) torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__) print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
print(device) print(device)
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SPEECH_FILE = "_assets/speech.wav" SPEECH_FILE = "_assets/speech.wav"
if not os.path.exists(SPEECH_FILE): if not os.path.exists(SPEECH_FILE):
os.makedirs('_assets', exist_ok=True) os.makedirs("_assets", exist_ok=True)
with open(SPEECH_FILE, 'wb') as file: with open(SPEECH_FILE, "wb") as file:
file.write(requests.get(SPEECH_URL).content) file.write(requests.get(SPEECH_URL).content)
###################################################################### ######################################################################
# Creating a pipeline # Creating a pipeline
# ------------------- # -------------------
# #
# First, we will create a Wav2Vec2 model that performs the feature # First, we will create a Wav2Vec2 model that performs the feature
# extraction and the classification. # extraction and the classification.
# #
# There are two types of Wav2Vec2 pre-trained weights available in # There are two types of Wav2Vec2 pre-trained weights available in
# torchaudio: the ones fine-tuned for the ASR task, and the ones that # torchaudio: the ones fine-tuned for the ASR task, and the ones that
# are not fine-tuned. # are not fine-tuned.
# #
# Wav2Vec2 (and HuBERT) models are trained in a self-supervised manner. They # Wav2Vec2 (and HuBERT) models are trained in a self-supervised manner. They
# are first trained on audio only for representation learning, then # are first trained on audio only for representation learning, then
# fine-tuned for a specific task with additional labels. # fine-tuned for a specific task with additional labels.
# #
# The pre-trained weights without fine-tuning can be fine-tuned # The pre-trained weights without fine-tuning can be fine-tuned
# for other downstream tasks as well, but this tutorial does not # for other downstream tasks as well, but this tutorial does not
# cover that. # cover that.
# #
# We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here. # We will use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.
# #
# There are multiple models available as # There are multiple models available as
# :py:mod:`torchaudio.pipelines`. Please check the documentation for # :py:mod:`torchaudio.pipelines`. Please check the documentation for
# the details of how they are trained. # the details of how they are trained.
# #
# The bundle object provides the interface to instantiate the model and other # The bundle object provides the interface to instantiate the model and other
# information. The sampling rate and the class labels are found as follows. # information. The sampling rate and the class labels are found as follows.
# #
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
...@@ -105,7 +105,7 @@ print("Labels:", bundle.get_labels()) ...@@ -105,7 +105,7 @@ print("Labels:", bundle.get_labels())
###################################################################### ######################################################################
# The model can be constructed as follows. This process will automatically # The model can be constructed as follows. This process will automatically
# fetch the pre-trained weights and load them into the model. # fetch the pre-trained weights and load them into the model.
# #
model = bundle.get_model().to(device) model = bundle.get_model().to(device)
...@@ -115,62 +115,62 @@ print(model.__class__) ...@@ -115,62 +115,62 @@ print(model.__class__)
###################################################################### ######################################################################
# Loading data # Loading data
# ------------ # ------------
# #
# We will use the speech data from `VOiCES # We will use the speech data from `VOiCES
# dataset <https://iqtlabs.github.io/voices/>`__, which is licensed under # dataset <https://iqtlabs.github.io/voices/>`__, which is licensed under
# Creative Commons BY 4.0. # Creative Commons BY 4.0.
# #
IPython.display.Audio(SPEECH_FILE) IPython.display.Audio(SPEECH_FILE)
###################################################################### ######################################################################
# To load data, we use :py:func:`torchaudio.load`. # To load data, we use :py:func:`torchaudio.load`.
# #
# If the sampling rate is different from what the pipeline expects, then # If the sampling rate is different from what the pipeline expects, then
# we can use :py:func:`torchaudio.functional.resample` for resampling. # we can use :py:func:`torchaudio.functional.resample` for resampling.
# #
# .. note:: # .. note::
# #
# - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well. # - :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.
# - When performing resampling multiple times on the same set of sample rates, # - When performing resampling multiple times on the same set of sample rates,
#     using :py:func:`torchaudio.transforms.Resample` might improve the performance. #     using :py:func:`torchaudio.transforms.Resample` might improve the performance.
# #
waveform, sample_rate = torchaudio.load(SPEECH_FILE) waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device) waveform = waveform.to(device)
if sample_rate != bundle.sample_rate: if sample_rate != bundle.sample_rate:
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate) waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
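######################################################################
# Not part of the original tutorial: if the same conversion is applied to
# many utterances, constructing a :py:class:`torchaudio.transforms.Resample`
# object once caches the resampling kernel so it can be reused across calls.
#
resampler = torchaudio.transforms.Resample(sample_rate, bundle.sample_rate).to(device)
# ``resampler(waveform)`` is equivalent to the functional call above.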
###################################################################### ######################################################################
# Extracting acoustic features # Extracting acoustic features
# ---------------------------- # ----------------------------
# #
# The next step is to extract acoustic features from the audio. # The next step is to extract acoustic features from the audio.
# #
# .. note:: # .. note::
# Wav2Vec2 models fine-tuned for ASR task can perform feature # Wav2Vec2 models fine-tuned for ASR task can perform feature
# extraction and classification with one step, but for the sake of the # extraction and classification with one step, but for the sake of the
# tutorial, we also show how to perform feature extraction here. # tutorial, we also show how to perform feature extraction here.
# #
with torch.inference_mode(): with torch.inference_mode():
features, _ = model.extract_features(waveform) features, _ = model.extract_features(waveform)
###################################################################### ######################################################################
# The returned ``features`` is a list of tensors. Each tensor is the output of # The returned ``features`` is a list of tensors. Each tensor is the output of
# a transformer layer. # a transformer layer.
# #
fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features))) fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features): for i, feats in enumerate(features):
ax[i].imshow(feats[0].cpu()) ax[i].imshow(feats[0].cpu())
ax[i].set_title(f"Feature from transformer layer {i+1}") ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension") ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)") ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
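######################################################################
# Not part of the original tutorial: a quick way to inspect the structure
# of ``features`` directly.
#
print(len(features))      # one tensor per transformer layer
print(features[0].shape)  # [batch, num_frames, feature_dim]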
...@@ -178,24 +178,24 @@ plt.show() ...@@ -178,24 +178,24 @@ plt.show()
###################################################################### ######################################################################
# Feature classification # Feature classification
# ---------------------- # ----------------------
# #
# Once the acoustic features are extracted, the next step is to classify # Once the acoustic features are extracted, the next step is to classify
# them into a set of categories. # them into a set of categories.
# #
# The Wav2Vec2 model provides a method to perform the feature extraction and # The Wav2Vec2 model provides a method to perform the feature extraction and
# classification in one step. # classification in one step.
# #
with torch.inference_mode(): with torch.inference_mode():
emission, _ = model(waveform) emission, _ = model(waveform)
###################################################################### ######################################################################
# The output is in the form of logits, not probabilities. # The output is in the form of logits, not probabilities.
# #
# Let’s visualize this. # Let’s visualize this.
# #
plt.imshow(emission[0].cpu().T) plt.imshow(emission[0].cpu().T)
plt.title("Classification result") plt.title("Classification result")
...@@ -208,62 +208,63 @@ print("Class labels:", bundle.get_labels()) ...@@ -208,62 +208,63 @@ print("Class labels:", bundle.get_labels())
###################################################################### ######################################################################
# We can see that there are strong indications for certain labels across # We can see that there are strong indications for certain labels across
# the timeline. # the timeline.
# #
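######################################################################
# Not part of the original tutorial: if you prefer per-frame probabilities,
# you can apply a softmax over the label dimension of the logits.
#
probs = torch.softmax(emission, dim=-1)
print(probs[0].sum(dim=-1))  # each frame now sums to 1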
###################################################################### ######################################################################
# Generating transcripts # Generating transcripts
# ---------------------- # ----------------------
# #
# From the sequence of label probabilities, now we want to generate # From the sequence of label probabilities, now we want to generate
# transcripts. The process to generate hypotheses is often called # transcripts. The process to generate hypotheses is often called
# “decoding”. # “decoding”.
# #
# Decoding is more elaborate than simple classification because # Decoding is more elaborate than simple classification because
# decoding at a certain time step can be affected by surrounding # decoding at a certain time step can be affected by surrounding
# observations. # observations.
# #
# For example, take words like ``night`` and ``knight``. Even if their # For example, take words like ``night`` and ``knight``. Even if their
# prior probability distributions are different (in typical conversations, # prior probability distributions are different (in typical conversations,
# ``night`` would occur way more often than ``knight``), to accurately # ``night`` would occur way more often than ``knight``), to accurately
# generate transcripts with ``knight``, such as ``a knight with a sword``, # generate transcripts with ``knight``, such as ``a knight with a sword``,
# the decoding process has to postpone the final decision until it sees # the decoding process has to postpone the final decision until it sees
# enough context. # enough context.
# #
# There are many decoding techniques proposed, and they require external # There are many decoding techniques proposed, and they require external
# resources, such as a word dictionary and language models. # resources, such as a word dictionary and language models.
# #
# In this tutorial, for the sake of simplicity, we will perform greedy # In this tutorial, for the sake of simplicity, we will perform greedy
# decoding which does not depend on such external components, and simply # decoding which does not depend on such external components, and simply
# pick up the best hypothesis at each time step. Therefore, the context # pick up the best hypothesis at each time step. Therefore, the context
# information is not used, and only one transcript can be generated. # information is not used, and only one transcript can be generated.
# #
# We start by defining the greedy decoding algorithm. # We start by defining the greedy decoding algorithm.
# #
class GreedyCTCDecoder(torch.nn.Module): class GreedyCTCDecoder(torch.nn.Module):
def __init__(self, labels, blank=0): def __init__(self, labels, blank=0):
super().__init__() super().__init__()
self.labels = labels self.labels = labels
self.blank = blank self.blank = blank
def forward(self, emission: torch.Tensor) -> str: def forward(self, emission: torch.Tensor) -> str:
"""Given a sequence emission over labels, get the best path string """Given a sequence emission over labels, get the best path string
Args: Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`. emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
Returns: Returns:
str: The resulting transcript str: The resulting transcript
""" """
indices = torch.argmax(emission, dim=-1) # [num_seq,] indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1) indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank] indices = [i for i in indices if i != self.blank]
return ''.join([self.labels[i] for i in indices]) return "".join([self.labels[i] for i in indices])
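######################################################################
# Not part of the original tutorial: a toy walk-through of the two steps
# inside ``forward``. With the hypothetical label set ``("-", "A", "B")``
# and blank index 0, the frame-wise indices ``[1, 1, 0, 2, 2, 2]`` collapse
# to ``[1, 0, 2]`` and, after dropping the blank, decode to ``"AB"``.
#
toy_decoder = GreedyCTCDecoder(labels=("-", "A", "B"))
toy_emission = torch.nn.functional.one_hot(
    torch.tensor([1, 1, 0, 2, 2, 2]), num_classes=3
).float()
print(toy_decoder(toy_emission))  # AB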
###################################################################### ######################################################################
# Now create the decoder object and decode the transcript. # Now create the decoder object and decode the transcript.
# #
decoder = GreedyCTCDecoder(labels=bundle.get_labels()) decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0]) transcript = decoder(emission[0])
...@@ -271,7 +272,7 @@ transcript = decoder(emission[0]) ...@@ -271,7 +272,7 @@ transcript = decoder(emission[0])
###################################################################### ######################################################################
# Let’s check the result and listen again to the audio. # Let’s check the result and listen again to the audio.
# #
print(transcript) print(transcript)
IPython.display.Audio(SPEECH_FILE) IPython.display.Audio(SPEECH_FILE)
...@@ -283,19 +284,19 @@ IPython.display.Audio(SPEECH_FILE) ...@@ -283,19 +284,19 @@ IPython.display.Audio(SPEECH_FILE)
# `here <https://distill.pub/2017/ctc/>`__. In CTC, a blank token (ϵ) is a # `here <https://distill.pub/2017/ctc/>`__. In CTC, a blank token (ϵ) is a
# special token which denotes "no output"; consecutive repeats of a symbol # special token which denotes "no output"; consecutive repeats of a symbol
# are collapsed, and blank tokens are simply ignored in decoding. # are collapsed, and blank tokens are simply ignored in decoding.
# #
###################################################################### ######################################################################
# Conclusion # Conclusion
# ---------- # ----------
# #
# In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to # In this tutorial, we looked at how to use :py:mod:`torchaudio.pipelines` to
# perform acoustic feature extraction and speech recognition. Constructing # perform acoustic feature extraction and speech recognition. Constructing
# a model and getting the emission is as short as two lines. # a model and getting the emission is as short as two lines.
# #
# :: # ::
# #
# model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model() # model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
# emission = model(waveforms, ...) # emission = model(waveforms, ...)
# #
...@@ -10,24 +10,24 @@ Text-to-Speech with Tacotron2 ...@@ -10,24 +10,24 @@ Text-to-Speech with Tacotron2
###################################################################### ######################################################################
# Overview # Overview
# -------- # --------
# #
# This tutorial shows how to build a text-to-speech pipeline, using the # This tutorial shows how to build a text-to-speech pipeline, using the
# pretrained Tacotron2 in torchaudio. # pretrained Tacotron2 in torchaudio.
# #
# The text-to-speech pipeline goes as follows: # The text-to-speech pipeline goes as follows:
# #
# 1. Text preprocessing # 1. Text preprocessing
# #
# First, the input text is encoded into a list of symbols. In this # First, the input text is encoded into a list of symbols. In this
# tutorial, we will use English characters and phonemes as the symbols. # tutorial, we will use English characters and phonemes as the symbols.
# #
# 2. Spectrogram generation # 2. Spectrogram generation
# #
#    From the encoded text, a spectrogram is generated. We use the ``Tacotron2`` #    From the encoded text, a spectrogram is generated. We use the ``Tacotron2``
# model for this. # model for this.
# #
# 3. Time-domain conversion # 3. Time-domain conversion
# #
# The last step is converting the spectrogram into the waveform. The # The last step is converting the spectrogram into the waveform. The
#    component that generates speech from the spectrogram is called a vocoder. #    component that generates speech from the spectrogram is called a vocoder.
# In this tutorial, three different vocoders are used, # In this tutorial, three different vocoders are used,
...@@ -35,23 +35,23 @@ Text-to-Speech with Tacotron2 ...@@ -35,23 +35,23 @@ Text-to-Speech with Tacotron2
# `Griffin-Lim <https://pytorch.org/audio/stable/transforms.html#griffinlim>`__, # `Griffin-Lim <https://pytorch.org/audio/stable/transforms.html#griffinlim>`__,
# and # and
# `Nvidia's WaveGlow <https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/>`__. # `Nvidia's WaveGlow <https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/>`__.
# #
# #
# The following figure illustrates the whole process. # The following figure illustrates the whole process.
# #
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/tacotron2_tts_pipeline.png # .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/tacotron2_tts_pipeline.png
# #
# All the related components are bundled in :py:func:`torchaudio.pipelines.Tacotron2TTSBundle`, # All the related components are bundled in :py:func:`torchaudio.pipelines.Tacotron2TTSBundle`,
# but this tutorial will also cover the process under the hood. # but this tutorial will also cover the process under the hood.
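######################################################################
# Not part of the original tutorial: in outline, the bundled flow that the
# rest of this tutorial unpacks looks roughly like the following.
#
# ::
#
#    bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
#    processor = bundle.get_text_processor()
#    tacotron2 = bundle.get_tacotron2()
#    vocoder = bundle.get_vocoder()
#    with torch.inference_mode():
#        processed, lengths = processor("Hello world!")
#        spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
#        waveforms, lengths = vocoder(spec, spec_lengths)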
###################################################################### ######################################################################
# Preparation # Preparation
# ----------- # -----------
# #
# First, we install the necessary dependencies. In addition to # First, we install the necessary dependencies. In addition to
# ``torchaudio``, ``DeepPhonemizer`` is required to perform phoneme-based # ``torchaudio``, ``DeepPhonemizer`` is required to perform phoneme-based
# encoding. # encoding.
# #
# When running this example in a notebook, install DeepPhonemizer # When running this example in a notebook, install DeepPhonemizer
# !pip3 install deep_phonemizer # !pip3 install deep_phonemizer
...@@ -63,7 +63,7 @@ import matplotlib.pyplot as plt ...@@ -63,7 +63,7 @@ import matplotlib.pyplot as plt
import IPython import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8] matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0) torch.random.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
...@@ -76,36 +76,38 @@ print(device) ...@@ -76,36 +76,38 @@ print(device)
###################################################################### ######################################################################
# Text Processing # Text Processing
# --------------- # ---------------
# #
###################################################################### ######################################################################
# Character-based encoding # Character-based encoding
# ~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~
# #
# In this section, we will go through how the character-based encoding # In this section, we will go through how the character-based encoding
# works. # works.
# #
# Since the pre-trained Tacotron2 model expects a specific set of symbol # Since the pre-trained Tacotron2 model expects a specific set of symbol
# tables, the same functionalities are available in ``torchaudio``. This # tables, the same functionalities are available in ``torchaudio``. This
# section is more of an explanation of the basics of encoding. # section is more of an explanation of the basics of encoding.
# #
# Firstly, we define the set of symbols. For example, we can use # Firstly, we define the set of symbols. For example, we can use
# ``'_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'``. Then, we will map # ``'_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'``. Then, we will map
# each character of the input text into the index of the corresponding # each character of the input text into the index of the corresponding
# symbol in the table. # symbol in the table.
# #
# The following is an example of such processing. In the example, symbols # The following is an example of such processing. In the example, symbols
# that are not in the table are ignored. # that are not in the table are ignored.
# #
symbols = '_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz' symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)} look_up = {s: i for i, s in enumerate(symbols)}
symbols = set(symbols) symbols = set(symbols)
def text_to_sequence(text): def text_to_sequence(text):
text = text.lower() text = text.lower()
return [look_up[s] for s in text if s in symbols] return [look_up[s] for s in text if s in symbols]
text = "Hello world! Text to speech!" text = "Hello world! Text to speech!"
print(text_to_sequence(text)) print(text_to_sequence(text))
...@@ -116,7 +118,7 @@ print(text_to_sequence(text)) ...@@ -116,7 +118,7 @@ print(text_to_sequence(text))
# what the pretrained Tacotron2 model expects. ``torchaudio`` provides the # what the pretrained Tacotron2 model expects. ``torchaudio`` provides the
# transform along with the pretrained model. For example, you can # transform along with the pretrained model. For example, you can
# instantiate and use such a transform as follows. # instantiate and use such a transform as follows.
# #
processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor() processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()
...@@ -132,33 +134,33 @@ print(lengths) ...@@ -132,33 +134,33 @@ print(lengths)
# When a list of texts is provided, the returned ``lengths`` variable # When a list of texts is provided, the returned ``lengths`` variable
# represents the valid length of each processed token sequence in the output # represents the valid length of each processed token sequence in the output
# batch. # batch.
# #
# The intermediate representation can be retrieved as follows. # The intermediate representation can be retrieved as follows.
# #
print([processor.tokens[i] for i in processed[0, :lengths[0]]]) print([processor.tokens[i] for i in processed[0, : lengths[0]]])
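######################################################################
# Not part of the original tutorial: a minimal sketch of passing a batch of
# texts of different lengths; shorter entries are padded and their valid
# lengths are reported in the returned ``lengths``.
#
texts = ["Hello world!", "Text to speech!"]
with torch.inference_mode():
    processed_batch, lengths_batch = processor(texts)
print(processed_batch.shape)  # [2, max_token_length]
print(lengths_batch)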
###################################################################### ######################################################################
# Phoneme-based encoding # Phoneme-based encoding
# ~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~
# #
# Phoneme-based encoding is similar to character-based encoding, but it # Phoneme-based encoding is similar to character-based encoding, but it
# uses a symbol table based on phonemes and a G2P (Grapheme-to-Phoneme) # uses a symbol table based on phonemes and a G2P (Grapheme-to-Phoneme)
# model. # model.
# #
# The details of the G2P model are out of the scope of this tutorial; we will # The details of the G2P model are out of the scope of this tutorial; we will
# just look at what the conversion looks like. # just look at what the conversion looks like.
# #
# Similar to the case of character-based encoding, the encoding process is # Similar to the case of character-based encoding, the encoding process is
# expected to match what a pretrained Tacotron2 model is trained on. # expected to match what a pretrained Tacotron2 model is trained on.
# ``torchaudio`` has an interface to create the process. # ``torchaudio`` has an interface to create the process.
# #
# The following code illustrates how to make and use the process. Behind # The following code illustrates how to make and use the process. Behind
# the scenes, a G2P model is created using the ``DeepPhonemizer`` package, and # the scenes, a G2P model is created using the ``DeepPhonemizer`` package, and
# the pretrained weights published by the author of ``DeepPhonemizer`` are # the pretrained weights published by the author of ``DeepPhonemizer`` are
# fetched. # fetched.
# #
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
...@@ -166,7 +168,7 @@ processor = bundle.get_text_processor() ...@@ -166,7 +168,7 @@ processor = bundle.get_text_processor()
text = "Hello world! Text to speech!" text = "Hello world! Text to speech!"
with torch.inference_mode(): with torch.inference_mode():
processed, lengths = processor(text) processed, lengths = processor(text)
print(processed) print(processed)
print(lengths) print(lengths)
...@@ -175,30 +177,30 @@ print(lengths) ...@@ -175,30 +177,30 @@ print(lengths)
###################################################################### ######################################################################
# Notice that the encoded values are different from the example of # Notice that the encoded values are different from the example of
# character-based encoding. # character-based encoding.
# #
# The intermediate representation looks like the following. # The intermediate representation looks like the following.
# #
print([processor.tokens[i] for i in processed[0, :lengths[0]]]) print([processor.tokens[i] for i in processed[0, : lengths[0]]])
###################################################################### ######################################################################
# Spectrogram Generation # Spectrogram Generation
# ---------------------- # ----------------------
# #
# ``Tacotron2`` is the model we use to generate a spectrogram from the # ``Tacotron2`` is the model we use to generate a spectrogram from the
# encoded text. For the details of the model, please refer to `the # encoded text. For the details of the model, please refer to `the
# paper <https://arxiv.org/abs/1712.05884>`__. # paper <https://arxiv.org/abs/1712.05884>`__.
# #
# It is easy to instantiate a Tacotron2 model with pretrained weights; # It is easy to instantiate a Tacotron2 model with pretrained weights;
# however, note that the input to Tacotron2 models needs to be processed # however, note that the input to Tacotron2 models needs to be processed
# by the matching text processor. # by the matching text processor.
# #
# :py:func:`torchaudio.pipelines.Tacotron2TTSBundle` bundles the matching # :py:func:`torchaudio.pipelines.Tacotron2TTSBundle` bundles the matching
# models and processors together so that it is easy to create the pipeline. # models and processors together so that it is easy to create the pipeline.
# #
# For the available bundles and their usage, please refer to :py:mod:`torchaudio.pipelines`. # For the available bundles and their usage, please refer to :py:mod:`torchaudio.pipelines`.
# #
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor() processor = bundle.get_text_processor()
...@@ -207,10 +209,10 @@ tacotron2 = bundle.get_tacotron2().to(device) ...@@ -207,10 +209,10 @@ tacotron2 = bundle.get_tacotron2().to(device)
text = "Hello world! Text to speech!" text = "Hello world! Text to speech!"
with torch.inference_mode(): with torch.inference_mode():
processed, lengths = processor(text) processed, lengths = processor(text)
processed = processed.to(device) processed = processed.to(device)
lengths = lengths.to(device) lengths = lengths.to(device)
spec, _, _ = tacotron2.infer(processed, lengths) spec, _, _ = tacotron2.infer(processed, lengths)
plt.imshow(spec[0].cpu().detach()) plt.imshow(spec[0].cpu().detach())
...@@ -219,36 +221,36 @@ plt.imshow(spec[0].cpu().detach()) ...@@ -219,36 +221,36 @@ plt.imshow(spec[0].cpu().detach())
###################################################################### ######################################################################
# Note that the ``Tacotron2.infer`` method performs multinomial sampling; # Note that the ``Tacotron2.infer`` method performs multinomial sampling;
# therefore, the process of generating the spectrogram incurs randomness. # therefore, the process of generating the spectrogram incurs randomness.
# #
fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3)) fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3))
for i in range(3): for i in range(3):
with torch.inference_mode(): with torch.inference_mode():
spec, spec_lengths, _ = tacotron2.infer(processed, lengths) spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
print(spec[0].shape) print(spec[0].shape)
ax[i].imshow(spec[0].cpu().detach()) ax[i].imshow(spec[0].cpu().detach())
plt.show() plt.show()
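######################################################################
# Not part of the original tutorial: because the randomness comes from
# ``torch``'s RNG, fixing the seed right before ``infer`` should make a
# particular spectrogram reproducible across runs.
#
torch.manual_seed(42)
with torch.inference_mode():
    spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
print(spec[0].shape)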
###################################################################### ######################################################################
# Waveform Generation # Waveform Generation
# ------------------- # -------------------
# #
# Once the spectrogram is generated, the last step is to recover the # Once the spectrogram is generated, the last step is to recover the
# waveform from the spectrogram. # waveform from the spectrogram.
# #
# ``torchaudio`` provides vocoders based on ``GriffinLim`` and # ``torchaudio`` provides vocoders based on ``GriffinLim`` and
# ``WaveRNN``. # ``WaveRNN``.
# #
###################################################################### ######################################################################
# WaveRNN # WaveRNN
# ~~~~~~~ # ~~~~~~~
# #
# Continuing from the previous section, we can instantiate the matching # Continuing from the previous section, we can instantiate the matching
# WaveRNN model from the same bundle. # WaveRNN model from the same bundle.
# #
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
...@@ -259,27 +261,29 @@ vocoder = bundle.get_vocoder().to(device) ...@@ -259,27 +261,29 @@ vocoder = bundle.get_vocoder().to(device)
text = "Hello world! Text to speech!" text = "Hello world! Text to speech!"
with torch.inference_mode(): with torch.inference_mode():
processed, lengths = processor(text) processed, lengths = processor(text)
processed = processed.to(device) processed = processed.to(device)
lengths = lengths.to(device) lengths = lengths.to(device)
spec, spec_lengths, _ = tacotron2.infer(processed, lengths) spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths) waveforms, lengths = vocoder(spec, spec_lengths)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9)) fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach()) ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach()) ax2.plot(waveforms[0].cpu().detach())
torchaudio.save("_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate) torchaudio.save(
"_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate
)
IPython.display.Audio("_assets/output_wavernn.wav") IPython.display.Audio("_assets/output_wavernn.wav")
###################################################################### ######################################################################
# Griffin-Lim # Griffin-Lim
# ~~~~~~~~~~~ # ~~~~~~~~~~~
# #
# Using the Griffin-Lim vocoder is the same as WaveRNN. You can instantiate # Using the Griffin-Lim vocoder is the same as WaveRNN. You can instantiate
# the vocoder object with the ``get_vocoder`` method and pass the spectrogram. # the vocoder object with the ``get_vocoder`` method and pass the spectrogram.
# #
bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH
...@@ -288,34 +292,49 @@ tacotron2 = bundle.get_tacotron2().to(device) ...@@ -288,34 +292,49 @@ tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device) vocoder = bundle.get_vocoder().to(device)
with torch.inference_mode(): with torch.inference_mode():
processed, lengths = processor(text) processed, lengths = processor(text)
processed = processed.to(device) processed = processed.to(device)
lengths = lengths.to(device) lengths = lengths.to(device)
spec, spec_lengths, _ = tacotron2.infer(processed, lengths) spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths) waveforms, lengths = vocoder(spec, spec_lengths)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9)) fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach()) ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach()) ax2.plot(waveforms[0].cpu().detach())
torchaudio.save("_assets/output_griffinlim.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate) torchaudio.save(
"_assets/output_griffinlim.wav",
waveforms[0:1].cpu(),
sample_rate=vocoder.sample_rate,
)
IPython.display.Audio("_assets/output_griffinlim.wav") IPython.display.Audio("_assets/output_griffinlim.wav")
###################################################################### ######################################################################
# Waveglow # Waveglow
# ~~~~~~~~ # ~~~~~~~~
# #
# Waveglow is a vocoder published by Nvidia. The pretrained weights are # Waveglow is a vocoder published by Nvidia. The pretrained weights are
# published on Torch Hub. One can instantiate the model using ``torch.hub`` # published on Torch Hub. One can instantiate the model using ``torch.hub``
# module. # module.
# #
# Workaround to load model mapped on GPU # Workaround to load model mapped on GPU
# https://stackoverflow.com/a/61840832 # https://stackoverflow.com/a/61840832
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp32', pretrained=False) waveglow = torch.hub.load(
checkpoint = torch.hub.load_state_dict_from_url('https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth', progress=False, map_location=device) "NVIDIA/DeepLearningExamples:torchhub",
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()} "nvidia_waveglow",
model_math="fp32",
pretrained=False,
)
checkpoint = torch.hub.load_state_dict_from_url(
"https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth", # noqa: E501
progress=False,
map_location=device,
)
state_dict = {
key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()
}
waveglow.load_state_dict(state_dict) waveglow.load_state_dict(state_dict)
waveglow = waveglow.remove_weightnorm(waveglow) waveglow = waveglow.remove_weightnorm(waveglow)
...@@ -323,7 +342,7 @@ waveglow = waveglow.to(device) ...@@ -323,7 +342,7 @@ waveglow = waveglow.to(device)
waveglow.eval() waveglow.eval()
with torch.no_grad(): with torch.no_grad():
waveforms = waveglow.infer(spec) waveforms = waveglow.infer(spec)
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9)) fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach()) ax1.imshow(spec[0].cpu().detach())
......