Commit 47716772 authored by moto, committed by Facebook GitHub Bot

Fix style to prep #3414 (#3415)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3415

Differential Revision: D46526437

Pulled By: mthrok

fbshipit-source-id: f78d19c19d7e68f67712412de35d9ed50f47263b
parent 91db978b
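The hunks below are mechanical style cleanup across CI scripts, example models, and tutorials, of the kind produced by a black-compatible formatter plus an import sorter (an inference from the patterns; the commit message does not name the tooling). A small runnable sketch of the recurring rewrites, using placeholder names rather than code from this commit:

# Placeholder illustration of the style rules applied throughout the diff.
n, size, i = 4, 4, 2
data = list(range(16))

tags = {"english_characters"}  # was: set(["english_characters"]) -- set literal
f0 = 344.0                     # was: 344. -- explicit trailing zero
label = "Actual"               # was: 'Actual' -- double quotes throughout
total = 0.0
for _ in range(1, n):          # was: for i in ... -- underscore for an unused loop variable
    total += f0**2             # was: f0 ** 2 -- ** hugs simple operands
window = data[i * size : (i + 1) * size]  # was: data[i*size:(i+1)*size] -- spaced complex slice
print(tags, label, total, window)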
@@ -8,7 +8,7 @@ to torchaudio with no labeling responsibility, so we don't want to bother them.
 import json
 import os
 import sys
-from typing import Any, Optional, Set, Tuple
+from typing import Any, Optional, Set

 import requests
@@ -87,7 +87,7 @@ Things in "examples" directory:
 - 'recipe' is applicable to training recipes under the 'examples' folder,
 - 'tutorial' is applicable to tutorials under the “examples/tutorials” folder
 - 'example' is applicable to everything else (e.g. C++ examples)
-- 'module: docs' is applicable to code documentations (not to tutorials). \
+- 'module: docs' is applicable to code documentations (not to tutorials).
 Regarding examples in code documentations, please also use 'module: docs'.
@@ -25,6 +25,7 @@ from datetime import datetime
 sys.path.insert(0, os.path.abspath("."))

 import pytorch_sphinx_theme

+
 # -- General configuration ------------------------------------------------
@@ -142,7 +142,7 @@ class ResNet(nn.Module):
                 )
             )
         self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
+        for _ in range(1, blocks):
             layers.append(
                 block(
                     self.inplanes,
@@ -174,7 +174,7 @@ class ResNet1D(nn.Module):
                 )
             )
         self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
+        for _ in range(1, blocks):
             layers.append(
                 block(
                     self.inplanes,
@@ -70,8 +70,8 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)
 _phonemizer = None

-available_symbol_set = set(["english_characters", "english_phonemes"])
-available_phonemizers = set(["DeepPhonemizer"])
+available_symbol_set = {"english_characters", "english_phonemes"}
+available_phonemizers = {"DeepPhonemizer"}


 def get_symbol_list(symbol_list: str = "english_characters", cmudict_root: Optional[str] = "./") -> List[str]:
@@ -35,17 +35,14 @@ print(torchaudio.__version__)
 #
 try:
-    from torchaudio.prototype.functional import (
-        oscillator_bank,
-        extend_pitch,
-        adsr_envelope,
-    )
+    from torchaudio.prototype.functional import adsr_envelope, extend_pitch, oscillator_bank
 except ModuleNotFoundError:
     print(
         "Failed to import prototype DSP features. "
         "Please install torchaudio nightly builds. "
         "Please refer to https://pytorch.org/get-started/locally "
-        "for instructions to install a nightly build.")
+        "for instructions to install a nightly build."
+    )
     raise

 import matplotlib.pyplot as plt
@@ -78,7 +75,7 @@ from IPython.display import Audio
 PI = torch.pi
 PI2 = 2 * torch.pi

-F0 = 344.  # fundamental frequency
+F0 = 344.0  # fundamental frequency
 DURATION = 1.1  # [seconds]
 SAMPLE_RATE = 16_000  # [Hz]
@@ -87,26 +84,19 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
 ######################################################################
 #

+
 def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
     t = torch.arange(waveform.size(0)) / sample_rate

     fig, axes = plt.subplots(4, 1, sharex=True)
     axes[0].plot(t, freq)
-    axes[0].set(
-        title=f"Oscillator bank (bank size: {amp.size(-1)})",
-        ylabel="Frequency [Hz]",
-        ylim=[-0.03, None])
+    axes[0].set(title=f"Oscillator bank (bank size: {amp.size(-1)})", ylabel="Frequency [Hz]", ylim=[-0.03, None])
     axes[1].plot(t, amp)
-    axes[1].set(
-        ylabel="Amplitude",
-        ylim=[-0.03 if torch.all(amp >= 0.0) else None, None])
+    axes[1].set(ylabel="Amplitude", ylim=[-0.03 if torch.all(amp >= 0.0) else None, None])
     axes[2].plot(t, waveform)
     axes[2].set(ylabel="Waveform")
     axes[3].specgram(waveform, Fs=sample_rate)
-    axes[3].set(
-        ylabel="Spectrogram",
-        xlabel="Time [s]",
-        xlim=[-0.01, t[-1] + 0.01])
+    axes[3].set(ylabel="Spectrogram", xlabel="Time [s]", xlim=[-0.01, t[-1] + 0.01])
     for i in range(4):
         axes[i].grid(True)
@@ -121,6 +111,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
     waveform /= waveform.abs().max()
     return Audio(vol * waveform, rate=sample_rate, normalize=False)

+
 ######################################################################
 # Harmonic Overtones
 # -------------------
@@ -159,10 +150,11 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
 # and adds extend pitch in accordance with the formula above.
 #

+
 def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
     freq = extend_pitch(freq0, num_pitches)

-    mults = [-((-1) ** i) / (PI * i) for i in range(1, 1+num_pitches)]
+    mults = [-((-1) ** i) / (PI * i) for i in range(1, 1 + num_pitches)]
     amp = extend_pitch(amp0, mults)
     waveform = oscillator_bank(freq, amp, sample_rate=sample_rate)
     return freq, amp, waveform
@@ -176,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
 freq0 = torch.full((NUM_FRAMES, 1), F0)
 amp0 = torch.ones((NUM_FRAMES, 1))
 freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
+show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 #
@@ -191,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
 freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
 freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
+show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Square wave
@@ -212,22 +204,23 @@ show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
+
 def square_wave(freq0, amp0, num_pitches, sample_rate):
-    mults = [2. * i + 1. for i in range(num_pitches)]
+    mults = [2.0 * i + 1.0 for i in range(num_pitches)]
     freq = extend_pitch(freq0, mults)

-    mults = [4 / (PI * (2. * i + 1.)) for i in range(num_pitches)]
+    mults = [4 / (PI * (2.0 * i + 1.0)) for i in range(num_pitches)]
     amp = extend_pitch(amp0, mults)
     waveform = oscillator_bank(freq, amp, sample_rate=sample_rate)
     return freq, amp, waveform

 ######################################################################
 #

 freq0 = torch.full((NUM_FRAMES, 1), F0)
 amp0 = torch.ones((NUM_FRAMES, 1))
-freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE/F0/2), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
+freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
+show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Triangle wave
@@ -248,11 +241,11 @@ show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
 def triangle_wave(freq0, amp0, num_pitches, sample_rate):
-    mults = [2. * i + 1. for i in range(num_pitches)]
+    mults = [2.0 * i + 1.0 for i in range(num_pitches)]
     freq = extend_pitch(freq0, mults)

-    c = 8 / (PI ** 2)
-    mults = [c * ((-1) ** i) / ((2. * i + 1.) ** 2) for i in range(num_pitches)]
+    c = 8 / (PI**2)
+    mults = [c * ((-1) ** i) / ((2.0 * i + 1.0) ** 2) for i in range(num_pitches)]
     amp = extend_pitch(amp0, mults)
     waveform = oscillator_bank(freq, amp, sample_rate=sample_rate)
@@ -263,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
 #
 freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
+show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Inharmonic Paritials
@@ -288,18 +281,18 @@ duration = 2.0
 num_frames = int(SAMPLE_RATE * duration)
 freq0 = torch.full((num_frames, 1), F0)
-mults = [0.56, 0.92, 1.19, 1.71, 2, 2.74, 3., 3.76, 4.07]
+mults = [0.56, 0.92, 1.19, 1.71, 2, 2.74, 3.0, 3.76, 4.07]
 freq = extend_pitch(freq0, mults)

 amp = adsr_envelope(
     num_frames=num_frames,
     attack=0.002,
     decay=0.998,
-    sustain=0.,
-    release=0.,
+    sustain=0.0,
+    release=0.0,
     n_decay=2,
 )
-amp = torch.stack([amp * (0.5 ** i) for i in range(num_tones)], dim=-1)
+amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
 waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
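Reading the coefficient lists in the oscillator hunks above back into formulas (a gloss of what the list comprehensions compute, with f_0 the fundamental frequency and K the number of partials passed to extend_pitch), the three generators realize truncated Fourier series:

\begin{aligned}
\text{sawtooth:} \quad & x(t) = \sum_{k=1}^{K} \frac{-(-1)^{k}}{\pi k}\,\sin(2\pi k f_0 t) \\
\text{square:} \quad & x(t) = \sum_{k=0}^{K-1} \frac{4}{\pi (2k+1)}\,\sin\bigl(2\pi (2k+1) f_0 t\bigr) \\
\text{triangle:} \quad & x(t) = \sum_{k=0}^{K-1} \frac{8}{\pi^{2}}\,\frac{(-1)^{k}}{(2k+1)^{2}}\,\sin\bigl(2\pi (2k+1) f_0 t\bigr)
\end{aligned}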
@@ -207,6 +207,7 @@ from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState
 class CustomLM(CTCDecoderLM):
     """Create a Python wrapper around `language_model` to feed to the decoder."""
+
     def __init__(self, language_model: torch.nn.Module):
         CTCDecoderLM.__init__(self)
         self.language_model = language_model
@@ -20,6 +20,8 @@ import torchaudio.functional as F
 print(torch.__version__)
 print(torchaudio.__version__)

+import matplotlib.pyplot as plt
+
 ######################################################################
 # Preparation
 # -----------
@@ -28,7 +30,6 @@ print(torchaudio.__version__)
 #
 from IPython.display import Audio
-import matplotlib.pyplot as plt

 from torchaudio.utils import download_asset
@@ -101,6 +102,7 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
     figure.suptitle(title)
     plt.show(block=False)

+
 ######################################################################
 #
@@ -122,6 +124,7 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
     figure.suptitle(title)
     plt.show(block=False)

+
 ######################################################################
 # Original
 # ~~~~~~~~
@@ -353,10 +356,12 @@ bg_added = F.add_noise(rir_applied, noise, snr_db)
 plot_specgram(bg_added, sample_rate, title="BG noise added")

 # Apply filtering and change sample rate
-effect = ",".join([
-    "lowpass=frequency=4000:poles=1",
-    "compand=attacks=0.02:decays=0.05:points=-60/-60|-30/-10|-20/-8|-5/-8|-2/-8:gain=-8:volume=-7:delay=0.05",
-])
+effect = ",".join(
+    [
+        "lowpass=frequency=4000:poles=1",
+        "compand=attacks=0.02:decays=0.05:points=-60/-60|-30/-10|-20/-8|-5/-8|-2/-8:gain=-8:volume=-7:delay=0.05",
+    ]
+)
 filtered = apply_effect(bg_added.T, sample_rate, effect)
 sample_rate2 = 8000
@@ -25,6 +25,9 @@ import torchaudio.transforms as T
 print(torch.__version__)
 print(torchaudio.__version__)

+import librosa
+import matplotlib.pyplot as plt
+
 ######################################################################
 # Preparation
 # -----------
@@ -38,8 +41,6 @@ print(torchaudio.__version__)
 # !pip install librosa
 #
 from IPython.display import Audio
-import librosa
-import matplotlib.pyplot as plt

 from torchaudio.utils import download_asset

 torch.random.manual_seed(0)
@@ -388,6 +389,7 @@ pitch = F.detect_pitch_frequency(SPEECH_WAVEFORM, SAMPLE_RATE)
 ######################################################################
 #

+
 def plot_pitch(waveform, sr, pitch):
     figure, axis = plt.subplots(1, 1)
     axis.set_title("Pitch Feature")
@@ -61,6 +61,7 @@ def _hide_seek(obj):
         def read(self, n):
             return self.obj.read(n)
+
     return _wrapper(obj)
@@ -294,7 +295,8 @@ with requests.get(url, stream=True) as response:
 print("Fetching until the requested frames are available...")
 with requests.get(url, stream=True) as response:
     waveform2, sample_rate2 = torchaudio.load(
-        _hide_seek(response.raw), frame_offset=frame_offset, num_frames=num_frames)
+        _hide_seek(response.raw), frame_offset=frame_offset, num_frames=num_frames
+    )
 print(f" - Fetched {response.raw.tell()} bytes")

 print("Checking the resulting waveform ... ", end="")
@@ -333,6 +335,7 @@ waveform, sample_rate = torchaudio.load(SAMPLE_WAV)
 ######################################################################
 #

+
 def inspect_file(path):
     print("-" * 10)
     print("Source:", path)
@@ -341,6 +344,7 @@ def inspect_file(path):
     print(f" - {torchaudio.info(path)}")
     print()

+
 ######################################################################
 #
 # Save without any encoding option.
@@ -27,14 +27,14 @@ import math
 import timeit

 import librosa
-import resampy
-import matplotlib.pyplot as plt
 import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
 import pandas as pd
-from IPython.display import Audio, display
+import resampy
+from IPython.display import Audio

-pd.set_option('display.max_rows', None)
-pd.set_option('display.max_columns', None)
+pd.set_option("display.max_rows", None)
+pd.set_option("display.max_columns", None)

 DEFAULT_OFFSET = 201
@@ -338,6 +338,7 @@ print(f"resampy: {resampy.__version__}")
 ######################################################################
 #

+
 def benchmark_resample_functional(
     waveform,
     sample_rate,
@@ -348,8 +349,9 @@ def benchmark_resample_functional(
     beta=None,
     iters=5,
 ):
-    return timeit.timeit(
-        stmt='''
+    return (
+        timeit.timeit(
+            stmt="""
 torchaudio.functional.resample(
     waveform,
     sample_rate,
@@ -359,16 +361,20 @@ torchaudio.functional.resample(
     resampling_method=resampling_method,
     beta=beta,
 )
-''',
-        setup='import torchaudio',
-        number=iters,
-        globals=locals(),
-    ) * 1000 / iters
+""",
+            setup="import torchaudio",
+            number=iters,
+            globals=locals(),
+        )
+        * 1000
+        / iters
+    )

 ######################################################################
 #

+
 def benchmark_resample_transforms(
     waveform,
     sample_rate,
@@ -379,9 +385,10 @@ def benchmark_resample_transforms(
     beta=None,
     iters=5,
 ):
-    return timeit.timeit(
-        stmt='resampler(waveform)',
-        setup='''
+    return (
+        timeit.timeit(
+            stmt="resampler(waveform)",
+            setup="""
 import torchaudio

 resampler = torchaudio.transforms.Resample(
@@ -394,15 +401,19 @@ resampler = torchaudio.transforms.Resample(
     beta=beta,
 )
 resampler.to(waveform.device)
-''',
-        number=iters,
-        globals=locals(),
-    ) * 1000 / iters
+""",
+            number=iters,
+            globals=locals(),
+        )
+        * 1000
+        / iters
+    )

 ######################################################################
 #

+
 def benchmark_resample_librosa(
     waveform,
     sample_rate,
@@ -411,24 +422,29 @@ def benchmark_resample_librosa(
     iters=5,
 ):
     waveform_np = waveform.squeeze().numpy()
-    return timeit.timeit(
-        stmt='''
+    return (
+        timeit.timeit(
+            stmt="""
 librosa.resample(
     waveform_np,
     orig_sr=sample_rate,
     target_sr=resample_rate,
     res_type=res_type,
 )
-''',
-        setup='import librosa',
-        number=iters,
-        globals=locals(),
-    ) * 1000 / iters
+""",
+            setup="import librosa",
+            number=iters,
+            globals=locals(),
+        )
+        * 1000
+        / iters
+    )

 ######################################################################
 #

+
 def benchmark(sample_rate, resample_rate):
     times, rows = [], []

     waveform = get_sine_sweep(sample_rate).to(torch.float32)
@@ -483,7 +499,7 @@ def plot(df):
     print(df.round(2))
     ax = df.plot(kind="bar")
     plt.ylabel("Time Elapsed [ms]")
-    plt.xticks(rotation = 0, fontsize=10)
+    plt.xticks(rotation=0, fontsize=10)
     for cont, col, color in zip(ax.containers, df.columns, mcolors.TABLEAU_COLORS):
         label = ["N/A" if v != v else str(v) for v in df[col].round(2)]
         ax.bar_label(cont, labels=label, color=color, fontweight="bold", fontsize="x-small")
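The return (timeit.timeit(...) * 1000 / iters) expressions reformatted above all follow one benchmarking pattern: time iters executions of a statement and report the mean in milliseconds, with globals=locals() injecting the enclosing function's arguments into the timed statement's namespace. A minimal self-contained sketch of that pattern (placeholder workload, not the tutorial's code):

import timeit


def bench_mean_ms(n, iters=5):
    # timeit.timeit returns total seconds for `iters` runs of stmt;
    # globals=locals() makes `n` visible inside the stmt string, and
    # * 1000 / iters converts the total into mean milliseconds per run.
    return (
        timeit.timeit(
            stmt="sum(range(n))",
            number=iters,
            globals=locals(),
        )
        * 1000
        / iters
    )


print(f"{bench_mean_ms(1_000_000):.2f} ms")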
@@ -39,14 +39,10 @@ except ModuleNotFoundError:
         "Failed to import the forced alignment API. "
         "Please install torchaudio nightly builds. "
         "Please refer to https://pytorch.org/get-started/locally "
-        "for instructions to install a nightly build.")
+        "for instructions to install a nightly build."
+    )
     raise

-import matplotlib
-import matplotlib.pyplot as plt
-from IPython.display import Audio

 ######################################################################
 # I. Basic usages
 # ---------------
@@ -71,7 +67,10 @@ from IPython.display import Audio
 # %matplotlib inline
 from dataclasses import dataclass

 import IPython
+import matplotlib
+import matplotlib.pyplot as plt

 matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
@@ -193,8 +192,6 @@ plt.show()
 # token-level and word-level alignments easily.
 #

-import torchaudio.functional as F

 @dataclass
 class Frame:
@@ -214,7 +211,7 @@ def compute_alignments(transcript, dictionary, emission):
     target_lengths = torch.tensor(targets.shape[0])

     # This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
-    frame_alignment, frame_scores = F.forced_align(emission, targets, input_lengths, target_lengths, 0)
+    frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)

     assert len(frame_alignment) == input_lengths.item()
     assert len(targets) == target_lengths.item()
@@ -279,8 +276,6 @@ def merge_repeats(frames, transcript):
         while i2 < len(frames) and frames[i1].token_index == frames[i2].token_index:
             i2 += 1
         score = sum(frames[k].score for k in range(i1, i2)) / (i2 - i1)
-        tokens = [dictionary[c] if c in dictionary else dictionary['@'] for c in transcript.replace(" ", "")]
         segments.append(
             Segment(
                 transcript_nospace[frames[i1].token_index],
@@ -370,7 +365,7 @@ def merge_words(transcript, segments, separator=" "):
                 s = len(words)
             else:
                 s = 0
-            segs = segments[i1 + s:i2 + s]
+            segs = segments[i1 + s : i2 + s]
             word = "".join([seg.label for seg in segs])
             score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
             words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
@@ -380,6 +375,7 @@ def merge_words(transcript, segments, separator=" "):
             i3 += 1
     return words

+
 word_segments = merge_words(transcript, segments, "|")
@@ -388,9 +384,10 @@ word_segments = merge_words(transcript, segments, "|")
 # ^^^^^^^^^^^^^
 #

+
 def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
     fig, ax2 = plt.subplots(figsize=(64, 12))
-    plt.rcParams.update({'font.size': 30})
+    plt.rcParams.update({"font.size": 30})

     # The original waveform
     ratio = waveform.size(0) / input_lengths
@@ -413,6 +410,7 @@ def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
     ax2.set_xlabel("time [second]", fontsize=40)
     ax2.set_yticks([])

+
 plot_alignments(
     segments,
     word_segments,
@@ -437,6 +435,7 @@ def display_segment(i, waveform, word_segments, frame_alignment):
     segment = waveform[:, x0:x1]
     return IPython.display.Audio(segment.numpy(), rate=sample_rate)

+
 # Generate the audio for each segment
 print(transcript)
 IPython.display.Audio(SPEECH_FILE)
@@ -532,7 +531,9 @@ model = wav2vec2_model(
     aux_num_out=31,
 )

-torch.hub.download_url_to_file("https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt", "model.pt")
+torch.hub.download_url_to_file(
+    "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt", "model.pt"
+)
 checkpoint = torch.load("model.pt", map_location="cpu")
 model.load_state_dict(checkpoint)
@@ -556,6 +557,7 @@ def get_emission(waveform):
     emission = emissions[0].cpu().detach()
     return emission, waveform

+
 emission, waveform = get_emission(waveform)

 # Construct the dictionary
@@ -602,15 +604,11 @@ def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
     frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission)
     segments = merge_repeats(frames, transcript)
     word_segments = merge_words(transcript, segments)
-    plot_alignments(
-        segments,
-        word_segments,
-        waveform[0],
-        emission.shape[0]
-    )
+    plot_alignments(segments, word_segments, waveform[0], emission.shape[0])
     plt.show()
     return word_segments, frame_alignment

+
 # One can follow the following steps to download the uroman romanizer and use it to obtain normalized transcripts.
 # def normalize_uroman(text):
 #     text = text.lower()
@@ -629,7 +627,9 @@ def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
 # text_normalized = normalize_uroman(lines[0].strip())

-text_normalized = "aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
+text_normalized = (
+    "aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
+)
 SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac")
 waveform, _ = torchaudio.load(SPEECH_FILE)
@@ -60,10 +60,7 @@ try:
     for k, v in torchaudio.utils.ffmpeg_utils.get_versions().items():
         print(k, v)
 except Exception:
-    raise RuntimeError(
-        "This tutorial requires FFmpeg libraries 4.2>,<5. "
-        "Please install FFmpeg."
-    )
+    raise RuntimeError("This tutorial requires FFmpeg libraries 4.2>,<5. " "Please install FFmpeg.")

 ######################################################################
 # Usage
@@ -107,11 +104,11 @@ waveform, sr = torchaudio.load(src, channels_first=False)
 #

-def show(effect=None, format=None, *, stereo=False):
+def show(effect, *, stereo=False):
     wf = torch.cat([waveform] * 2, dim=1) if stereo else waveform
     figsize = (6.4, 2.1 if stereo else 1.2)

-    effector = AudioEffector(effect=effect, format=format, pad_end=False)
+    effector = AudioEffector(effect=effect, pad_end=False)
     result = effector.apply(wf, int(sr))

     num_channels = result.size(1)
@@ -128,7 +125,7 @@ def show(effect=None, format=None, *, stereo=False):
 # --------
 #

-show(effect=None, format=None)
+show(effect=None)

 ######################################################################
 # Effects
@@ -139,131 +136,138 @@ show(effect=None, format=None)
 # tempo
 # ~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#atempo
-show(effect="atempo=0.7")
+show("atempo=0.7")

 ######################################################################
 #
-show(effect="atempo=1.8")
+show("atempo=1.8")

 ######################################################################
 # highpass
 # ~~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#highpass
-show(effect="highpass=frequency=1500")
+show("highpass=frequency=1500")

 ######################################################################
 # lowpass
 # ~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#lowpass
-show(effect="lowpass=frequency=1000")
+show("lowpass=frequency=1000")

 ######################################################################
 # allpass
 # ~~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#allpass
-show(effect="allpass")
+show("allpass")

 ######################################################################
 # bandpass
 # ~~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#bandpass
-show(effect="bandpass=frequency=3000")
+show("bandpass=frequency=3000")

 ######################################################################
 # bandreject
 # ~~~~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#bandreject
-show(effect="bandreject=frequency=3000")
+show("bandreject=frequency=3000")

 ######################################################################
 # echo
 # ~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#aecho
-show(effect="aecho=in_gain=0.8:out_gain=0.88:delays=6:decays=0.4")
+show("aecho=in_gain=0.8:out_gain=0.88:delays=6:decays=0.4")

 ######################################################################
 #
-show(effect="aecho=in_gain=0.8:out_gain=0.88:delays=60:decays=0.4")
+show("aecho=in_gain=0.8:out_gain=0.88:delays=60:decays=0.4")

 ######################################################################
 #
-show(effect="aecho=in_gain=0.8:out_gain=0.9:delays=1000:decays=0.3")
+show("aecho=in_gain=0.8:out_gain=0.9:delays=1000:decays=0.3")

 ######################################################################
 # chorus
 # ~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#chorus
-show(effect=("chorus=0.5:0.9:50|60|40:0.4|0.32|0.3:0.25|0.4|0.3:2|2.3|1.3"))
+show("chorus=0.5:0.9:50|60|40:0.4|0.32|0.3:0.25|0.4|0.3:2|2.3|1.3")

 ######################################################################
 # fft filter
 # ~~~~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#afftfilt
-show(effect=(
+# fmt: off
+show(
     "afftfilt="
     "real='re * (1-clip(b * (b/nb), 0, 1))':"
-    "imag='im * (1-clip(b * (b/nb), 0, 1))'"))
+    "imag='im * (1-clip(b * (b/nb), 0, 1))'"
+)

 ######################################################################
 #
-show(effect=(
+show(
     "afftfilt="
     "real='hypot(re,im) * sin(0)':"
     "imag='hypot(re,im) * cos(0)':"
     "win_size=512:"
-    "overlap=0.75"))
+    "overlap=0.75"
+)

 ######################################################################
 #
-show(effect=(
+show(
     "afftfilt="
     "real='hypot(re,im) * cos(2 * 3.14 * (random(0) * 2-1))':"
     "imag='hypot(re,im) * sin(2 * 3.14 * (random(1) * 2-1))':"
     "win_size=128:"
-    "overlap=0.8"))
+    "overlap=0.8"
+)
+# fmt: on

 ######################################################################
 # vibrato
 # ~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#vibrato
-show(effect=("vibrato=f=10:d=0.8"))
+show("vibrato=f=10:d=0.8")

 ######################################################################
 # tremolo
 # ~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#tremolo
-show(effect=("tremolo=f=8:d=0.8"))
+show("tremolo=f=8:d=0.8")

 ######################################################################
 # crystalizer
 # ~~~~~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#crystalizer
-show(effect=("crystalizer"))
+show("crystalizer")

 ######################################################################
 # flanger
 # ~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#flanger
-show(effect=("flanger"))
+show("flanger")

 ######################################################################
 # phaser
 # ~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#aphaser
-show(effect=("aphaser"))
+show("aphaser")

 ######################################################################
 # pulsator
 # ~~~~~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#apulsator
-show(effect=("apulsator"), stereo=True)
+show("apulsator", stereo=True)

 ######################################################################
 # haas
 # ~~~~
 # https://ffmpeg.org/ffmpeg-filters.html#haas
-show(effect=("haas"))
+show("haas")

 ######################################################################
 # Codecs
@@ -292,11 +296,13 @@ def show_multi(configs):
 # ~~~
 #

-results = show_multi([
-    {"format": "ogg"},
-    {"format": "ogg", "encoder": "vorbis"},
-    {"format": "ogg", "encoder": "opus"},
-])
+results = show_multi(
+    [
+        {"format": "ogg"},
+        {"format": "ogg", "encoder": "vorbis"},
+        {"format": "ogg", "encoder": "opus"},
+    ]
+)

 ######################################################################
 # ogg - default encoder (flac)
@@ -321,7 +327,8 @@ results[2]
 # ~~~
 # https://trac.ffmpeg.org/wiki/Encode/MP3
-results = show_multi([
-    {"format": "mp3"},
-    {"format": "mp3", "codec_config": CodecConfig(compression_level=1)},
-    {"format": "mp3", "codec_config": CodecConfig(compression_level=9)},
+results = show_multi(
+    [
+        {"format": "mp3"},
+        {"format": "mp3", "codec_config": CodecConfig(compression_level=1)},
+        {"format": "mp3", "codec_config": CodecConfig(compression_level=9)},
@@ -329,7 +336,8 @@ results = show_multi([
-    {"format": "mp3", "codec_config": CodecConfig(bit_rate=8_000)},
-    {"format": "mp3", "codec_config": CodecConfig(qscale=9)},
-    {"format": "mp3", "codec_config": CodecConfig(qscale=1)},
-])
+        {"format": "mp3", "codec_config": CodecConfig(bit_rate=8_000)},
+        {"format": "mp3", "codec_config": CodecConfig(qscale=9)},
+        {"format": "mp3", "codec_config": CodecConfig(qscale=1)},
+    ]
+)

 ######################################################################
 # default
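Stripped of the plotting, the show() helper simplified above reduces to the AudioEffector call pattern these hunks exercise: construct with an FFmpeg filter expression, then apply to a (frames, channels) tensor. A sketch under that assumption (placeholder input; the constructor and apply signature are as used in the diff):

import torch
from torchaudio.io import AudioEffector

sample_rate = 16000
waveform = torch.zeros((sample_rate, 1))  # one second of silence, (frames, channels)

# The effect string is an FFmpeg filter expression, e.g. atempo, lowpass, aecho.
effector = AudioEffector(effect="atempo=0.7", pad_end=False)
slowed = effector.apply(waveform, sample_rate)
print(slowed.shape)  # more frames than the input, since playback is slowed down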
@@ -27,14 +27,11 @@ import torchaudio
 print(torch.__version__)
 print(torchaudio.__version__)

+import matplotlib.pyplot as plt

 ######################################################################
 #
-from torchaudio.prototype.functional import (
-    sinc_impulse_response,
-    frequency_impulse_response,
-)
-import matplotlib.pyplot as plt
+from torchaudio.prototype.functional import frequency_impulse_response, sinc_impulse_response

 ######################################################################
 #
@@ -75,7 +72,7 @@ import matplotlib.pyplot as plt
 # :py:func:`~torchaudio.prototype.functional.sinc_impulse_response`.
 #

-cutoff = torch.linspace(0., 1., 9)
+cutoff = torch.linspace(0.0, 1.0, 9)
 irs = sinc_impulse_response(cutoff, window_size=513)

 print("Cutoff shape:", cutoff.shape)
@@ -87,6 +84,7 @@ print("Impulse response shape:", irs.shape)
 # Let's visualize the resulting impulse responses.
 #

+
 def plot_sinc_ir(irs, cutoff):
     num_filts, window_size = irs.shape
     half = window_size // 2
@@ -99,7 +97,8 @@ def plot_sinc_ir(irs, cutoff):
         ax.grid(True)
     fig.suptitle(
         "Impulse response of sinc low-pass filter for different cut-off frequencies\n"
-        "(Frequencies are relative to Nyquist frequency)")
+        "(Frequencies are relative to Nyquist frequency)"
+    )
     axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
     plt.tight_layout()
@@ -126,12 +125,12 @@ frs = torch.fft.rfft(irs, n=2048, dim=1).abs()
 # Let's visualize the resulting frequency responses.
 #

+
 def plot_sinc_fr(frs, cutoff, band=False):
     num_filts, num_fft = frs.shape
     num_ticks = num_filts + 1 if band else num_filts

-    fig, axes = plt.subplots(
-        num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5))
+    fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5))
     for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
         ax.grid(True)
         ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
@@ -141,11 +140,12 @@ def plot_sinc_fr(frs, cutoff, band=False):
             yticks=[1e-9, 1e-6, 1e-3, 1],
             xticks=torch.linspace(0, num_fft, num_ticks),
             xticklabels=[f"{i/(num_ticks - 1)}" for i in range(num_ticks)],
-            xlabel="Frequency"
+            xlabel="Frequency",
         )
     fig.suptitle(
         "Frequency response of sinc low-pass filter for different cut-off frequencies\n"
-        "(Frequencies are relative to Nyquist frequency)")
+        "(Frequencies are relative to Nyquist frequency)"
+    )
     plt.tight_layout()
@@ -193,13 +193,11 @@ plot_sinc_fr(frs, cutoff)
 # Band-pass filter can be obtained by subtracting low-pass filter for
 # upper band from that of lower band.

-cutoff = torch.linspace(0., 1, 11)
+cutoff = torch.linspace(0.0, 1, 11)
 c_low = cutoff[:-1]
 c_high = cutoff[1:]

-irs = (
-    sinc_impulse_response(c_low, window_size=513)
-    - sinc_impulse_response(c_high, window_size=513))
+irs = sinc_impulse_response(c_low, window_size=513) - sinc_impulse_response(c_high, window_size=513)
 frs = torch.fft.rfft(irs, n=2048, dim=1).abs()

 ######################################################################
@@ -256,6 +254,7 @@ print("Impulse Response:", ir.shape)
 ######################################################################
 #

+
 def plot_ir(magnitudes, ir, num_fft=2048):
     fr = torch.fft.rfft(ir, n=num_fft, dim=0).abs()
     ir_size = ir.size(-1)
@@ -268,17 +267,18 @@ def plot_ir(magnitudes, ir, num_fft=2048):
     axes[0].set(title="Impulse Response")
     axes[0].set_xticks([i * half // 4 for i in range(-4, 5)])
     t = torch.linspace(0, 1, fr.numel())
-    axes[1].plot(t, fr, label='Actual')
-    axes[2].semilogy(t, fr, label='Actual')
+    axes[1].plot(t, fr, label="Actual")
+    axes[2].semilogy(t, fr, label="Actual")
     t = torch.linspace(0, 1, magnitudes.numel())
     for i in range(1, 3):
-        axes[i].plot(t, magnitudes, label='Desired (input)', linewidth=1.1, linestyle='--')
+        axes[i].plot(t, magnitudes, label="Desired (input)", linewidth=1.1, linestyle="--")
         axes[i].grid(True)
     axes[1].set(title="Frequency Response")
     axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
     axes[2].legend(loc="lower right")
     fig.tight_layout()

+
 ######################################################################
 #
@@ -305,7 +305,7 @@ plot_ir(magnitudes, ir)
 #
 #

-magnitudes = torch.linspace(0, 1, 64)**4.0
+magnitudes = torch.linspace(0, 1, 64) ** 4.0
 ir = frequency_impulse_response(magnitudes)
@@ -316,7 +316,7 @@ plot_ir(magnitudes, ir)
 ######################################################################
 #

-magnitudes = torch.sin(torch.linspace(0, 10, 64))**4.0
+magnitudes = torch.sin(torch.linspace(0, 10, 64)) ** 4.0
 ir = frequency_impulse_response(magnitudes)
@@ -450,6 +450,7 @@ plot_alignments(
 )
 plt.show()

+
 ################################################################################
 #
@@ -45,6 +45,8 @@ import torchaudio
 print(torch.__version__)
 print(torchaudio.__version__)

+import matplotlib.pyplot as plt
+
 ######################################################################
 # In addition to ``torchaudio``, ``mir_eval`` is required to perform
 # signal-to-distortion ratio (SDR) calculations. To install ``mir_eval``
@@ -52,30 +54,9 @@ print(torchaudio.__version__)
 #
 from IPython.display import Audio
+from mir_eval import separation
+from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
 from torchaudio.utils import download_asset
-import matplotlib.pyplot as plt
-
-try:
-    from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
-    from mir_eval import separation
-except ModuleNotFoundError:
-    try:
-        import google.colab
-        print(
-            """
-            To enable running this notebook in Google Colab, install nightly
-            torch and torchaudio builds by adding the following code block to the top
-            of the notebook before running it:
-            !pip3 uninstall -y torch torchvision torchaudio
-            !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
-            !pip3 install mir_eval
-            """
-        )
-    except ModuleNotFoundError:
-        pass
-    raise

 ######################################################################
 # 3. Construct the pipeline
@@ -132,7 +113,7 @@ from torchaudio.transforms import Fade
 def separate_sources(
     model,
     mix,
-    segment=10.,
+    segment=10.0,
     overlap=0.1,
     device=None,
 ):
@@ -157,7 +138,7 @@ def separate_sources(
     start = 0
     end = chunk_len
     overlap_frames = overlap * sample_rate
-    fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape='linear')
+    fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")

     final = torch.zeros(batch, len(model.sources), channels, length, device=device)
@@ -265,12 +246,13 @@ stft = torchaudio.transforms.Spectrogram(
 # scores.
 #

+
 def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor, source: str):
-    print("SDR score is:",
-          separation.bss_eval_sources(
-              original_source.detach().numpy(),
-              predicted_source.detach().numpy())[0].mean())
-    plot_spectrogram(stft(predicted_source)[0], f'Spectrogram {source}')
+    print(
+        "SDR score is:",
+        separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
+    )
+    plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}")
     return Audio(predicted_source, rate=sample_rate)
@@ -285,19 +267,19 @@ bass_original = download_asset("tutorial-assets/hdemucs_bass_segment.wav")
 vocals_original = download_asset("tutorial-assets/hdemucs_vocals_segment.wav")
 other_original = download_asset("tutorial-assets/hdemucs_other_segment.wav")

-drums_spec = audios["drums"][:, frame_start: frame_end].cpu()
+drums_spec = audios["drums"][:, frame_start:frame_end].cpu()
 drums, sample_rate = torchaudio.load(drums_original)

-bass_spec = audios["bass"][:, frame_start: frame_end].cpu()
+bass_spec = audios["bass"][:, frame_start:frame_end].cpu()
 bass, sample_rate = torchaudio.load(bass_original)

-vocals_spec = audios["vocals"][:, frame_start: frame_end].cpu()
+vocals_spec = audios["vocals"][:, frame_start:frame_end].cpu()
 vocals, sample_rate = torchaudio.load(vocals_original)

-other_spec = audios["other"][:, frame_start: frame_end].cpu()
+other_spec = audios["other"][:, frame_start:frame_end].cpu()
 other, sample_rate = torchaudio.load(other_original)

-mix_spec = mixture[:, frame_start: frame_end].cpu()
+mix_spec = mixture[:, frame_start:frame_end].cpu()

 ######################################################################
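separate_sources, touched above, runs the model on fixed-length chunks and stitches them with torchaudio.transforms.Fade so overlapping regions cross-fade linearly. A simplified sketch of that overlap-add loop (identity stand-in for the model, final-chunk edge cases elided):

import torch
from torchaudio.transforms import Fade

sample_rate = 8000
mix = torch.randn(1, 2, 5 * sample_rate)  # (batch, channels, frames)
chunk_len, overlap = 2 * sample_rate, sample_rate // 4
fade = Fade(fade_in_len=0, fade_out_len=overlap, fade_shape="linear")

output = torch.zeros_like(mix)
start = 0
while start < mix.size(-1):
    end = min(start + chunk_len + overlap, mix.size(-1))
    output[:, :, start:end] += fade(mix[:, :, start:end])  # the "model" is identity here
    if start == 0:
        fade.fade_in_len = overlap  # later chunks fade in over the overlapped region
    start += chunk_len
print(output.shape)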
...@@ -37,6 +37,10 @@ print(torch.__version__) ...@@ -37,6 +37,10 @@ print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
import matplotlib.pyplot as plt
import mir_eval
from IPython.display import Audio
###################################################################### ######################################################################
# 2. Preparation # 2. Preparation
# -------------- # --------------
...@@ -59,10 +63,6 @@ print(torchaudio.__version__) ...@@ -59,10 +63,6 @@ print(torchaudio.__version__)
from pesq import pesq from pesq import pesq
from pystoi import stoi from pystoi import stoi
import mir_eval
import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset from torchaudio.utils import download_asset
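The pesq and pystoi imports kept above expose small functional APIs. A hedged smoke test, with synthetic noise standing in for real speech (PESQ's "wb" mode assumes 16 kHz input):

import numpy as np
from pesq import pesq
from pystoi import stoi

fs = 16000
rng = np.random.default_rng(0)
clean = rng.standard_normal(fs)  # one second of stand-in "reference" audio
noisy = clean + 0.1 * rng.standard_normal(fs)

print("PESQ (wide-band):", pesq(fs, clean, noisy, mode="wb"))
print("STOI:", stoi(clean, noisy, fs, extended=False))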
###################################################################### ######################################################################
......
...@@ -45,30 +45,9 @@ import torchaudio ...@@ -45,30 +45,9 @@ import torchaudio
print(torch.__version__) print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
######################################################################
#
import IPython import IPython
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.io import StreamReader
try:
from torchaudio.io import StreamReader
except ModuleNotFoundError:
try:
import google.colab
print(
"""
To enable running this notebook in Google Colab, install the requisite
third party libraries by running the following code block:
!add-apt-repository -y ppa:savoury1/ffmpeg4
!apt-get -qq install -y ffmpeg
"""
)
except ModuleNotFoundError:
pass
raise
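Since the Colab fallback around this import is removed here, a hedged sketch of basic StreamReader usage for anyone missing the ffmpeg-backed module; "example.wav" is a placeholder path.

from torchaudio.io import StreamReader

reader = StreamReader("example.wav")  # placeholder source path
reader.add_basic_audio_stream(frames_per_chunk=16000, sample_rate=16000)
for (chunk,) in reader.stream():
    print(chunk.shape)  # (frames, channels) per decoded chunk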
###################################################################### ######################################################################
# 3. Construct the pipeline # 3. Construct the pipeline
...@@ -202,11 +181,11 @@ def _plot(feats, num_iter, unit=25): ...@@ -202,11 +181,11 @@ def _plot(feats, num_iter, unit=25):
fig, axes = plt.subplots(num_plots, 1) fig, axes = plt.subplots(num_plots, 1)
t0 = 0 t0 = 0
for i, ax in enumerate(axes): for i, ax in enumerate(axes):
feats_ = feats[i*unit:(i+1)*unit] feats_ = feats[i * unit : (i + 1) * unit]
t1 = t0 + segment_length / sample_rate * len(feats_) t1 = t0 + segment_length / sample_rate * len(feats_)
feats_ = torch.cat([f[2:-2] for f in feats_]) # remove boundary effect and overlap feats_ = torch.cat([f[2:-2] for f in feats_]) # remove boundary effect and overlap
ax.imshow(feats_.T, extent=[t0, t1, 0, 1], aspect="auto", origin="lower") ax.imshow(feats_.T, extent=[t0, t1, 0, 1], aspect="auto", origin="lower")
ax.tick_params(which='both', left=False, labelleft=False) ax.tick_params(which="both", left=False, labelleft=False)
ax.set_xlim(t0, t0 + unit_dur) ax.set_xlim(t0, t0 + unit_dur)
t0 = t1 t0 = t1
fig.suptitle("MelSpectrogram Feature") fig.suptitle("MelSpectrogram Feature")
......
...@@ -28,19 +28,18 @@ print(torchaudio.__version__) ...@@ -28,19 +28,18 @@ print(torchaudio.__version__)
# #
try: try:
from torchaudio.prototype.functional import ( from torchaudio.prototype.functional import adsr_envelope, oscillator_bank
oscillator_bank,
adsr_envelope,
)
except ModuleNotFoundError: except ModuleNotFoundError:
print( print(
"Failed to import prototype DSP features. " "Failed to import prototype DSP features. "
"Please install torchaudio nightly builds. " "Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally " "Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build.") "for instructions to install a nightly build."
)
raise raise
import math import math
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from IPython.display import Audio from IPython.display import Audio
...@@ -93,7 +92,7 @@ PI2 = 2 * torch.pi ...@@ -93,7 +92,7 @@ PI2 = 2 * torch.pi
# the rest of the tutorial. # the rest of the tutorial.
# #
F0 = 344. # fundamental frequency F0 = 344.0 # fundamental frequency
DURATION = 1.1 # [seconds] DURATION = 1.1 # [seconds]
SAMPLE_RATE = 16_000 # [Hz] SAMPLE_RATE = 16_000 # [Hz]
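For these constants, the NUM_FRAMES = int(DURATION * SAMPLE_RATE) definition in the next hunk works out to int(1.1 * 16_000) = 17_600 samples.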
...@@ -102,26 +101,19 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE) ...@@ -102,26 +101,19 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
###################################################################### ######################################################################
# #
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.3): def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.3):
t = torch.arange(waveform.size(0)) / sample_rate t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True) fig, axes = plt.subplots(4, 1, sharex=True)
axes[0].plot(t, freq) axes[0].plot(t, freq)
axes[0].set( axes[0].set(title=f"Oscillator bank (bank size: {amp.size(-1)})", ylabel="Frequency [Hz]", ylim=[-0.03, None])
title=f"Oscillator bank (bank size: {amp.size(-1)})",
ylabel="Frequency [Hz]",
ylim=[-0.03, None])
axes[1].plot(t, amp) axes[1].plot(t, amp)
axes[1].set( axes[1].set(ylabel="Amplitude", ylim=[-0.03 if torch.all(amp >= 0.0) else None, None])
ylabel="Amplitude",
ylim=[-0.03 if torch.all(amp >= 0.0) else None, None])
axes[2].plot(t, waveform) axes[2].plot(t, waveform)
axes[2].set(ylabel="Waveform") axes[2].set(ylabel="Waveform")
axes[3].specgram(waveform, Fs=sample_rate) axes[3].specgram(waveform, Fs=sample_rate)
axes[3].set( axes[3].set(ylabel="Spectrogram", xlabel="Time [s]", xlim=[-0.01, t[-1] + 0.01])
ylabel="Spectrogram",
xlabel="Time [s]",
xlim=[-0.01, t[-1] + 0.01])
for i in range(4): for i in range(4):
axes[i].grid(True) axes[i].grid(True)
...@@ -147,7 +139,7 @@ amp = torch.ones((NUM_FRAMES, 1)) ...@@ -147,7 +139,7 @@ amp = torch.ones((NUM_FRAMES, 1))
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0)) show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Combining multiple sine waves # Combining multiple sine waves
...@@ -166,7 +158,7 @@ amp = torch.ones((NUM_FRAMES, 3)) / 3 ...@@ -166,7 +158,7 @@ amp = torch.ones((NUM_FRAMES, 3)) / 3
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0)) show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
...@@ -279,7 +271,8 @@ amp = torch.stack( ...@@ -279,7 +271,8 @@ amp = torch.stack(
adsr_envelope(unit, attack=0.01, hold=0.125, decay=0.12, sustain=0.05, release=0), adsr_envelope(unit, attack=0.01, hold=0.125, decay=0.12, sustain=0.05, release=0),
adsr_envelope(unit, attack=0.01, hold=0.25, decay=0.08, sustain=0, release=0), adsr_envelope(unit, attack=0.01, hold=0.25, decay=0.08, sustain=0, release=0),
), ),
dim=-1) dim=-1,
)
amp = amp.repeat(repeat, 1) / 2 amp = amp.repeat(repeat, 1) / 2
bass = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) bass = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
...@@ -316,7 +309,7 @@ show(freq, amp, doremi, SAMPLE_RATE) ...@@ -316,7 +309,7 @@ show(freq, amp, doremi, SAMPLE_RATE)
# ~~~~~ # ~~~~~
# #
env = adsr_envelope(NUM_FRAMES * 6, attack=0.98, decay=0., sustain=1, release=0.02) env = adsr_envelope(NUM_FRAMES * 6, attack=0.98, decay=0.0, sustain=1, release=0.02)
tones = [ tones = [
484.90, # B4 484.90, # B4
......
...@@ -78,12 +78,11 @@ print(torchaudio.__version__) ...@@ -78,12 +78,11 @@ print(torchaudio.__version__)
# #
try: try:
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE
from torchaudio.prototype.pipelines import SQUIM_SUBJECTIVE
from pesq import pesq from pesq import pesq
from pystoi import stoi from pystoi import stoi
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError: except ImportError:
import google.colab import google.colab # noqa: F401
print( print(
""" """
...@@ -98,14 +97,15 @@ except ImportError: ...@@ -98,14 +97,15 @@ except ImportError:
) )
import matplotlib.pyplot as plt
###################################################################### ######################################################################
# #
# #
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio.utils import download_asset
from IPython.display import Audio from IPython.display import Audio
import matplotlib.pyplot as plt from torchaudio.utils import download_asset
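For context on the SQUIM bundles consolidated above, a hedged sketch of the objective pipeline's usage as of this commit; the random waveform is a stand-in for real 16 kHz speech, which is needed for meaningful scores.

import torch
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE

model = SQUIM_OBJECTIVE.get_model()
waveform = torch.rand(1, 16000)  # stand-in one-second, 16 kHz mono clip
stoi_hat, pesq_hat, si_sdr_hat = model(waveform)  # reference-free estimates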
def si_snr(estimate, reference, epsilon=1e-8): def si_snr(estimate, reference, epsilon=1e-8):
......
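The body of si_snr is elided in this diff; below is a hedged sketch of the standard scale-invariant SNR definition such a helper typically implements, not necessarily the tutorial's exact code.

import torch

def si_snr_sketch(estimate, reference, epsilon=1e-8):
    # Zero-mean both signals, then project the estimate onto the reference.
    reference = reference - reference.mean(dim=-1, keepdim=True)
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    ref_energy = (reference**2).sum(dim=-1, keepdim=True) + epsilon
    scale = (estimate * reference).sum(dim=-1, keepdim=True) / ref_energy
    projection = scale * reference  # target component of the estimate
    noise = estimate - projection  # everything orthogonal to the reference
    ratio = (projection**2).sum(dim=-1) / ((noise**2).sum(dim=-1) + epsilon)
    return 10 * torch.log10(ratio + epsilon)

print(si_snr_sketch(torch.rand(16000), torch.rand(16000)))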