Commit 47716772 authored by moto, committed by Facebook GitHub Bot

Fix style to prep #3414 (#3415)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3415

Differential Revision: D46526437

Pulled By: mthrok

fbshipit-source-id: f78d19c19d7e68f67712412de35d9ed50f47263b
parent 91db978b
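Most hunks below are mechanical style fixes (trailing whitespace, float and set literals, unused loop variables, import sorting, line wrapping); a few also drop unused imports and parameters. A minimal sketch of the recurring patterns (hypothetical code, not taken from the repository):

# set(...) calls rewritten as set literals
available_formats = {"wav", "flac"}  # was: set(["wav", "flac"])

# float literals written out in full
F0 = 344.0  # was: F0 = 344.

# unused loop variables renamed to "_"
layers = []
for _ in range(3):  # was: for i in range(3):
    layers.append(None)

# long argument lists wrapped one item per line
print(
    "arguments that exceed the line-length limit",
    "are split across multiple lines",
)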
......@@ -8,7 +8,7 @@ to torchaudio with no labeling responsibility, so we don't want to bother them.
import json
import os
import sys
from typing import Any, Optional, Set, Tuple
from typing import Any, Optional, Set
import requests
......@@ -84,12 +84,12 @@ Use 'module: ops' for operations under 'torchaudio/{transforms, functional}', \
and ML-related components under 'torchaudio/csrc' (e.g. RNN-T loss).
Things in "examples" directory:
- 'recipe' is applicable to training recipes under the 'examples' folder,
- 'recipe' is applicable to training recipes under the 'examples' folder,
- 'tutorial' is applicable to tutorials under the “examples/tutorials” folder
- 'example' is applicable to everything else (e.g. C++ examples)
- 'module: docs' is applicable to code documentations (not to tutorials). \
- 'example' is applicable to everything else (e.g. C++ examples)
- 'module: docs' is applicable to code documentations (not to tutorials).
Regarding examples in code documentations, please also use 'module: docs'.
Regarding examples in code documentations, please also use 'module: docs'.
Please use 'other' tag only when you’re sure the changes are not much relevant to users, \
or when all other tags are not applicable. Try not to use it often, in order to minimize \
......@@ -98,7 +98,7 @@ efforts required when we prepare release notes.
---
When preparing release notes, please make sure 'documentation' and 'tutorials' occur as the \
last sub-categories under each primary category like 'new feature', 'improvements' or 'prototype'.
last sub-categories under each primary category like 'new feature', 'improvements' or 'prototype'.
Things related to build are by default excluded from the release note, \
except when it impacts users. For example:
......
......@@ -25,6 +25,7 @@ from datetime import datetime
sys.path.insert(0, os.path.abspath("."))
import pytorch_sphinx_theme
# -- General configuration ------------------------------------------------
......
......@@ -142,7 +142,7 @@ class ResNet(nn.Module):
)
)
self.inplanes = planes * block.expansion
for i in range(1, blocks):
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
......
......@@ -174,7 +174,7 @@ class ResNet1D(nn.Module):
)
)
self.inplanes = planes * block.expansion
for i in range(1, blocks):
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
......
......@@ -70,8 +70,8 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)
_phonemizer = None
available_symbol_set = set(["english_characters", "english_phonemes"])
available_phonemizers = set(["DeepPhonemizer"])
available_symbol_set = {"english_characters", "english_phonemes"}
available_phonemizers = {"DeepPhonemizer"}
def get_symbol_list(symbol_list: str = "english_characters", cmudict_root: Optional[str] = "./") -> List[str]:
......
......@@ -35,17 +35,14 @@ print(torchaudio.__version__)
#
try:
from torchaudio.prototype.functional import (
oscillator_bank,
extend_pitch,
adsr_envelope,
)
from torchaudio.prototype.functional import adsr_envelope, extend_pitch, oscillator_bank
except ModuleNotFoundError:
print(
"Failed to import prototype DSP features. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build.")
"for instructions to install a nightly build."
)
raise
import matplotlib.pyplot as plt
......@@ -78,7 +75,7 @@ from IPython.display import Audio
PI = torch.pi
PI2 = 2 * torch.pi
F0 = 344. # fundamental frequency
F0 = 344.0 # fundamental frequency
DURATION = 1.1 # [seconds]
SAMPLE_RATE = 16_000 # [Hz]
......@@ -87,26 +84,19 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
######################################################################
#
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True)
axes[0].plot(t, freq)
axes[0].set(
title=f"Oscillator bank (bank size: {amp.size(-1)})",
ylabel="Frequency [Hz]",
ylim=[-0.03, None])
axes[0].set(title=f"Oscillator bank (bank size: {amp.size(-1)})", ylabel="Frequency [Hz]", ylim=[-0.03, None])
axes[1].plot(t, amp)
axes[1].set(
ylabel="Amplitude",
ylim=[-0.03 if torch.all(amp >= 0.0) else None, None])
axes[1].set(ylabel="Amplitude", ylim=[-0.03 if torch.all(amp >= 0.0) else None, None])
axes[2].plot(t, waveform)
axes[2].set(ylabel="Waveform")
axes[3].specgram(waveform, Fs=sample_rate)
axes[3].set(
ylabel="Spectrogram",
xlabel="Time [s]",
xlim=[-0.01, t[-1] + 0.01])
axes[3].set(ylabel="Spectrogram", xlabel="Time [s]", xlim=[-0.01, t[-1] + 0.01])
for i in range(4):
axes[i].grid(True)
......@@ -121,6 +111,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
waveform /= waveform.abs().max()
return Audio(vol * waveform, rate=sample_rate, normalize=False)
######################################################################
# Harmonic Overtones
# -------------------
......@@ -159,10 +150,11 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
# and adds extend pitch in accordance with the formula above.
#
def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq = extend_pitch(freq0, num_pitches)
mults = [-((-1) ** i) / (PI * i) for i in range(1, 1+num_pitches)]
mults = [-((-1) ** i) / (PI * i) for i in range(1, 1 + num_pitches)]
amp = extend_pitch(amp0, mults)
waveform = oscillator_bank(freq, amp, sample_rate=sample_rate)
return freq, amp, waveform
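In terms of the multipliers above, the synthesized signal is a truncated Fourier series of a sawtooth wave (assuming the oscillator uses a sine phase convention):

x(t) = \sum_{k=1}^{N} \frac{(-1)^{k+1}}{\pi k} \sin(2\pi k f_0 t),

with N = num_pitches and f_0 the fundamental frequency.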
......@@ -176,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
#
......@@ -191,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Square wave
......@@ -212,22 +204,23 @@ show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
def square_wave(freq0, amp0, num_pitches, sample_rate):
mults = [2. * i + 1. for i in range(num_pitches)]
mults = [2.0 * i + 1.0 for i in range(num_pitches)]
freq = extend_pitch(freq0, mults)
mults = [4 / (PI * (2. * i + 1.)) for i in range(num_pitches)]
mults = [4 / (PI * (2.0 * i + 1.0)) for i in range(num_pitches)]
amp = extend_pitch(amp0, mults)
waveform = oscillator_bank(freq, amp, sample_rate=sample_rate)
return freq, amp, waveform
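Similarly, the multipliers above form a truncated Fourier series of a square wave, which contains only odd harmonics:

x(t) = \sum_{i=0}^{N-1} \frac{4}{\pi (2i+1)} \sin\bigl(2\pi (2i+1) f_0 t\bigr).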
######################################################################
#
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE/F0/2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Triangle wave
......@@ -248,11 +241,11 @@ show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
def triangle_wave(freq0, amp0, num_pitches, sample_rate):
mults = [2. * i + 1. for i in range(num_pitches)]
mults = [2.0 * i + 1.0 for i in range(num_pitches)]
freq = extend_pitch(freq0, mults)
c = 8 / (PI ** 2)
mults = [c * ((-1) ** i) / ((2. * i + 1.) ** 2) for i in range(num_pitches)]
c = 8 / (PI**2)
mults = [c * ((-1) ** i) / ((2.0 * i + 1.0) ** 2) for i in range(num_pitches)]
amp = extend_pitch(amp0, mults)
waveform = oscillator_bank(freq, amp, sample_rate=sample_rate)
......@@ -263,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
#
freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Inharmonic Paritials
......@@ -288,18 +281,18 @@ duration = 2.0
num_frames = int(SAMPLE_RATE * duration)
freq0 = torch.full((num_frames, 1), F0)
mults = [0.56, 0.92, 1.19, 1.71, 2, 2.74, 3., 3.76, 4.07]
mults = [0.56, 0.92, 1.19, 1.71, 2, 2.74, 3.0, 3.76, 4.07]
freq = extend_pitch(freq0, mults)
amp = adsr_envelope(
num_frames=num_frames,
attack=0.002,
decay=0.998,
sustain=0.,
release=0.,
sustain=0.0,
release=0.0,
n_decay=2,
)
amp = torch.stack([amp * (0.5 ** i) for i in range(num_tones)], dim=-1)
amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
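With the per-partial gains above, the resulting bell-like tone is, schematically (env denoting the ADSR envelope, sine phase convention assumed):

x(t) \approx \sum_{i} 0.5^{i} \, \mathrm{env}(t) \, \sin(2\pi m_i f_0 t),

where m_i runs over the inharmonic multipliers listed above.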
......
......@@ -207,6 +207,7 @@ from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState
class CustomLM(CTCDecoderLM):
"""Create a Python wrapper around `language_model` to feed to the decoder."""
def __init__(self, language_model: torch.nn.Module):
CTCDecoderLM.__init__(self)
self.language_model = language_model
......
......@@ -20,6 +20,8 @@ import torchaudio.functional as F
print(torch.__version__)
print(torchaudio.__version__)
import matplotlib.pyplot as plt
######################################################################
# Preparation
# -----------
......@@ -28,7 +30,6 @@ print(torchaudio.__version__)
#
from IPython.display import Audio
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
......@@ -101,6 +102,7 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
figure.suptitle(title)
plt.show(block=False)
######################################################################
#
......@@ -122,6 +124,7 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
figure.suptitle(title)
plt.show(block=False)
######################################################################
# Original
# ~~~~~~~~
......@@ -353,10 +356,12 @@ bg_added = F.add_noise(rir_applied, noise, snr_db)
plot_specgram(bg_added, sample_rate, title="BG noise added")
# Apply filtering and change sample rate
effect = ",".join([
"lowpass=frequency=4000:poles=1",
"compand=attacks=0.02:decays=0.05:points=-60/-60|-30/-10|-20/-8|-5/-8|-2/-8:gain=-8:volume=-7:delay=0.05",
])
effect = ",".join(
[
"lowpass=frequency=4000:poles=1",
"compand=attacks=0.02:decays=0.05:points=-60/-60|-30/-10|-20/-8|-5/-8|-2/-8:gain=-8:volume=-7:delay=0.05",
]
)
filtered = apply_effect(bg_added.T, sample_rate, effect)
sample_rate2 = 8000
......
......@@ -25,6 +25,9 @@ import torchaudio.transforms as T
print(torch.__version__)
print(torchaudio.__version__)
import librosa
import matplotlib.pyplot as plt
######################################################################
# Preparation
# -----------
......@@ -38,8 +41,6 @@ print(torchaudio.__version__)
# !pip install librosa
#
from IPython.display import Audio
import librosa
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
torch.random.manual_seed(0)
......@@ -388,6 +389,7 @@ pitch = F.detect_pitch_frequency(SPEECH_WAVEFORM, SAMPLE_RATE)
######################################################################
#
def plot_pitch(waveform, sr, pitch):
figure, axis = plt.subplots(1, 1)
axis.set_title("Pitch Feature")
......
......@@ -61,6 +61,7 @@ def _hide_seek(obj):
def read(self, n):
return self.obj.read(n)
return _wrapper(obj)
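Only a fragment of _hide_seek is visible in this hunk; a plausible reconstruction (illustrative only, the elided lines are assumed) wraps a stream so that only read is exposed and the loader cannot seek on it:

def _hide_seek(obj):
    # Expose only `read` so callers such as torchaudio.load cannot seek
    # on the underlying (non-seekable) HTTP response stream.
    class _wrapper:
        def __init__(self, obj):
            self.obj = obj

        def read(self, n):
            return self.obj.read(n)

    return _wrapper(obj)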
......@@ -294,7 +295,8 @@ with requests.get(url, stream=True) as response:
print("Fetching until the requested frames are available...")
with requests.get(url, stream=True) as response:
waveform2, sample_rate2 = torchaudio.load(
_hide_seek(response.raw), frame_offset=frame_offset, num_frames=num_frames)
_hide_seek(response.raw), frame_offset=frame_offset, num_frames=num_frames
)
print(f" - Fetched {response.raw.tell()} bytes")
print("Checking the resulting waveform ... ", end="")
......@@ -333,6 +335,7 @@ waveform, sample_rate = torchaudio.load(SAMPLE_WAV)
######################################################################
#
def inspect_file(path):
print("-" * 10)
print("Source:", path)
......@@ -341,6 +344,7 @@ def inspect_file(path):
print(f" - {torchaudio.info(path)}")
print()
######################################################################
#
# Save without any encoding option.
......
......@@ -27,14 +27,14 @@ import math
import timeit
import librosa
import resampy
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import Audio, display
import resampy
from IPython.display import Audio
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
DEFAULT_OFFSET = 201
......@@ -338,6 +338,7 @@ print(f"resampy: {resampy.__version__}")
######################################################################
#
def benchmark_resample_functional(
waveform,
sample_rate,
......@@ -348,8 +349,9 @@ def benchmark_resample_functional(
beta=None,
iters=5,
):
return timeit.timeit(
stmt='''
return (
timeit.timeit(
stmt="""
torchaudio.functional.resample(
waveform,
sample_rate,
......@@ -359,16 +361,20 @@ torchaudio.functional.resample(
resampling_method=resampling_method,
beta=beta,
)
''',
setup='import torchaudio',
number=iters,
globals=locals(),
) * 1000 / iters
""",
setup="import torchaudio",
number=iters,
globals=locals(),
)
* 1000
/ iters
)
######################################################################
#
def benchmark_resample_transforms(
waveform,
sample_rate,
......@@ -379,9 +385,10 @@ def benchmark_resample_transforms(
beta=None,
iters=5,
):
return timeit.timeit(
stmt='resampler(waveform)',
setup='''
return (
timeit.timeit(
stmt="resampler(waveform)",
setup="""
import torchaudio
resampler = torchaudio.transforms.Resample(
......@@ -394,15 +401,19 @@ resampler = torchaudio.transforms.Resample(
beta=beta,
)
resampler.to(waveform.device)
''',
number=iters,
globals=locals(),
) * 1000 / iters
""",
number=iters,
globals=locals(),
)
* 1000
/ iters
)
######################################################################
#
def benchmark_resample_librosa(
waveform,
sample_rate,
......@@ -411,24 +422,29 @@ def benchmark_resample_librosa(
iters=5,
):
waveform_np = waveform.squeeze().numpy()
return timeit.timeit(
stmt='''
return (
timeit.timeit(
stmt="""
librosa.resample(
waveform_np,
orig_sr=sample_rate,
target_sr=resample_rate,
res_type=res_type,
)
''',
setup='import librosa',
number=iters,
globals=locals(),
) * 1000 / iters
""",
setup="import librosa",
number=iters,
globals=locals(),
)
* 1000
/ iters
)
######################################################################
#
def benchmark(sample_rate, resample_rate):
times, rows = [], []
waveform = get_sine_sweep(sample_rate).to(torch.float32)
......@@ -483,7 +499,7 @@ def plot(df):
print(df.round(2))
ax = df.plot(kind="bar")
plt.ylabel("Time Elapsed [ms]")
plt.xticks(rotation = 0, fontsize=10)
plt.xticks(rotation=0, fontsize=10)
for cont, col, color in zip(ax.containers, df.columns, mcolors.TABLEAU_COLORS):
label = ["N/A" if v != v else str(v) for v in df[col].round(2)]
ax.bar_label(cont, labels=label, color=color, fontweight="bold", fontsize="x-small")
......
......@@ -10,7 +10,7 @@ This tutorial shows how to align transcripts to speech with
`“Scaling Speech Technology to 1,000+
Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__,
and two advanced usages, i.e. dealing with non-English data and
transcription errors.
transcription errors.
Though there’s some overlap in visualization
diagrams, the scope here is different from the `“Forced Alignment with
......@@ -39,14 +39,10 @@ except ModuleNotFoundError:
"Failed to import the forced alignment API. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build.")
"for instructions to install a nightly build."
)
raise
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Audio
######################################################################
# I. Basic usages
# ---------------
......@@ -71,7 +67,10 @@ from IPython.display import Audio
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
......@@ -193,8 +192,6 @@ plt.show()
# token-level and word-level alignments easily.
#
import torchaudio.functional as F
@dataclass
class Frame:
......@@ -214,7 +211,7 @@ def compute_alignments(transcript, dictionary, emission):
target_lengths = torch.tensor(targets.shape[0])
# This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
frame_alignment, frame_scores = F.forced_align(emission, targets, input_lengths, target_lengths, 0)
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
assert len(frame_alignment) == input_lengths.item()
assert len(targets) == target_lengths.item()
......@@ -279,8 +276,6 @@ def merge_repeats(frames, transcript):
while i2 < len(frames) and frames[i1].token_index == frames[i2].token_index:
i2 += 1
score = sum(frames[k].score for k in range(i1, i2)) / (i2 - i1)
tokens = [dictionary[c] if c in dictionary else dictionary['@'] for c in transcript.replace(" ", "")]
segments.append(
Segment(
transcript_nospace[frames[i1].token_index],
......@@ -370,7 +365,7 @@ def merge_words(transcript, segments, separator=" "):
s = len(words)
else:
s = 0
segs = segments[i1 + s:i2 + s]
segs = segments[i1 + s : i2 + s]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
......@@ -380,6 +375,7 @@ def merge_words(transcript, segments, separator=" "):
i3 += 1
return words
word_segments = merge_words(transcript, segments, "|")
......@@ -388,9 +384,10 @@ word_segments = merge_words(transcript, segments, "|")
# ^^^^^^^^^^^^^
#
def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
fig, ax2 = plt.subplots(figsize=(64, 12))
plt.rcParams.update({'font.size': 30})
plt.rcParams.update({"font.size": 30})
# The original waveform
ratio = waveform.size(0) / input_lengths
......@@ -413,6 +410,7 @@ def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
ax2.set_xlabel("time [second]", fontsize=40)
ax2.set_yticks([])
plot_alignments(
segments,
word_segments,
......@@ -437,6 +435,7 @@ def display_segment(i, waveform, word_segments, frame_alignment):
segment = waveform[:, x0:x1]
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
# Generate the audio for each segment
print(transcript)
IPython.display.Audio(SPEECH_FILE)
......@@ -532,7 +531,9 @@ model = wav2vec2_model(
aux_num_out=31,
)
torch.hub.download_url_to_file("https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt", "model.pt")
torch.hub.download_url_to_file(
"https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt", "model.pt"
)
checkpoint = torch.load("model.pt", map_location="cpu")
model.load_state_dict(checkpoint)
......@@ -556,6 +557,7 @@ def get_emission(waveform):
emission = emissions[0].cpu().detach()
return emission, waveform
emission, waveform = get_emission(waveform)
# Construct the dictionary
......@@ -602,15 +604,11 @@ def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission)
segments = merge_repeats(frames, transcript)
word_segments = merge_words(transcript, segments)
plot_alignments(
segments,
word_segments,
waveform[0],
emission.shape[0]
)
plot_alignments(segments, word_segments, waveform[0], emission.shape[0])
plt.show()
return word_segments, frame_alignment
# One can follow the following steps to download the uroman romanizer and use it to obtain normalized transcripts.
# def normalize_uroman(text):
# text = text.lower()
......@@ -622,14 +620,16 @@ def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
# echo 'aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid' > test.txt"
# git clone https://github.com/isi-nlp/uroman
# uroman/bin/uroman.pl < test.txt > test_romanized.txt
#
#
# file = "test_romanized.txt"
# f = open(file, "r")
# lines = f.readlines()
# text_normalized = normalize_uroman(lines[0].strip())
text_normalized = "aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
text_normalized = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac")
waveform, _ = torchaudio.load(SPEECH_FILE)
......
......@@ -60,10 +60,7 @@ try:
for k, v in torchaudio.utils.ffmpeg_utils.get_versions().items():
print(k, v)
except Exception:
raise RuntimeError(
"This tutorial requires FFmpeg libraries 4.2>,<5. "
"Please install FFmpeg."
)
raise RuntimeError("This tutorial requires FFmpeg libraries 4.2>,<5. " "Please install FFmpeg.")
######################################################################
# Usage
......@@ -107,11 +104,11 @@ waveform, sr = torchaudio.load(src, channels_first=False)
#
def show(effect=None, format=None, *, stereo=False):
def show(effect, *, stereo=False):
wf = torch.cat([waveform] * 2, dim=1) if stereo else waveform
figsize = (6.4, 2.1 if stereo else 1.2)
effector = AudioEffector(effect=effect, format=format, pad_end=False)
effector = AudioEffector(effect=effect, pad_end=False)
result = effector.apply(wf, int(sr))
num_channels = result.size(1)
......@@ -128,7 +125,7 @@ def show(effect=None, format=None, *, stereo=False):
# --------
#
show(effect=None, format=None)
show(effect=None)
######################################################################
# Effects
......@@ -139,131 +136,138 @@ show(effect=None, format=None)
# tempo
# ~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#atempo
show(effect="atempo=0.7")
show("atempo=0.7")
######################################################################
#
show(effect="atempo=1.8")
show("atempo=1.8")
######################################################################
# highpass
# ~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#highpass
show(effect="highpass=frequency=1500")
show("highpass=frequency=1500")
######################################################################
# lowpass
# ~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#lowpass
show(effect="lowpass=frequency=1000")
show("lowpass=frequency=1000")
######################################################################
# allpass
# ~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#allpass
show(effect="allpass")
show("allpass")
######################################################################
# bandpass
# ~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#bandpass
show(effect="bandpass=frequency=3000")
show("bandpass=frequency=3000")
######################################################################
# bandreject
# ~~~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#bandreject
show(effect="bandreject=frequency=3000")
show("bandreject=frequency=3000")
######################################################################
# echo
# ~~~~
# https://ffmpeg.org/ffmpeg-filters.html#aecho
show(effect="aecho=in_gain=0.8:out_gain=0.88:delays=6:decays=0.4")
show("aecho=in_gain=0.8:out_gain=0.88:delays=6:decays=0.4")
######################################################################
#
show(effect="aecho=in_gain=0.8:out_gain=0.88:delays=60:decays=0.4")
show("aecho=in_gain=0.8:out_gain=0.88:delays=60:decays=0.4")
######################################################################
#
show(effect="aecho=in_gain=0.8:out_gain=0.9:delays=1000:decays=0.3")
show("aecho=in_gain=0.8:out_gain=0.9:delays=1000:decays=0.3")
######################################################################
# chorus
# ~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#chorus
show(effect=("chorus=0.5:0.9:50|60|40:0.4|0.32|0.3:0.25|0.4|0.3:2|2.3|1.3"))
show("chorus=0.5:0.9:50|60|40:0.4|0.32|0.3:0.25|0.4|0.3:2|2.3|1.3")
######################################################################
# fft filter
# ~~~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#afftfilt
show(effect=(
# fmt: off
show(
"afftfilt="
"real='re * (1-clip(b * (b/nb), 0, 1))':"
"imag='im * (1-clip(b * (b/nb), 0, 1))'"))
"imag='im * (1-clip(b * (b/nb), 0, 1))'"
)
######################################################################
#
show(effect=(
show(
"afftfilt="
"real='hypot(re,im) * sin(0)':"
"imag='hypot(re,im) * cos(0)':"
"win_size=512:"
"overlap=0.75"))
"overlap=0.75"
)
######################################################################
#
show(effect=(
show(
"afftfilt="
"real='hypot(re,im) * cos(2 * 3.14 * (random(0) * 2-1))':"
"imag='hypot(re,im) * sin(2 * 3.14 * (random(1) * 2-1))':"
"win_size=128:"
"overlap=0.8"))
"overlap=0.8"
)
# fmt: on
######################################################################
# vibrato
# ~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#vibrato
show(effect=("vibrato=f=10:d=0.8"))
show("vibrato=f=10:d=0.8")
######################################################################
# tremolo
# ~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#tremolo
show(effect=("tremolo=f=8:d=0.8"))
show("tremolo=f=8:d=0.8")
######################################################################
# crystalizer
# ~~~~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#crystalizer
show(effect=("crystalizer"))
show("crystalizer")
######################################################################
# flanger
# ~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#flanger
show(effect=("flanger"))
show("flanger")
######################################################################
# phaser
# ~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#aphaser
show(effect=("aphaser"))
show("aphaser")
######################################################################
# pulsator
# ~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#apulsator
show(effect=("apulsator"), stereo=True)
show("apulsator", stereo=True)
######################################################################
# haas
# ~~~~
# https://ffmpeg.org/ffmpeg-filters.html#haas
show(effect=("haas"))
show("haas")
######################################################################
# Codecs
......@@ -292,11 +296,13 @@ def show_multi(configs):
# ~~~
#
results = show_multi([
{"format": "ogg"},
{"format": "ogg", "encoder": "vorbis"},
{"format": "ogg", "encoder": "opus"},
])
results = show_multi(
[
{"format": "ogg"},
{"format": "ogg", "encoder": "vorbis"},
{"format": "ogg", "encoder": "opus"},
]
)
######################################################################
# ogg - default encoder (flac)
......@@ -321,15 +327,17 @@ results[2]
# ~~~
# https://trac.ffmpeg.org/wiki/Encode/MP3
results = show_multi([
{"format": "mp3"},
{"format": "mp3", "codec_config": CodecConfig(compression_level=1)},
{"format": "mp3", "codec_config": CodecConfig(compression_level=9)},
{"format": "mp3", "codec_config": CodecConfig(bit_rate=192_000)},
{"format": "mp3", "codec_config": CodecConfig(bit_rate=8_000)},
{"format": "mp3", "codec_config": CodecConfig(qscale=9)},
{"format": "mp3", "codec_config": CodecConfig(qscale=1)},
])
results = show_multi(
[
{"format": "mp3"},
{"format": "mp3", "codec_config": CodecConfig(compression_level=1)},
{"format": "mp3", "codec_config": CodecConfig(compression_level=9)},
{"format": "mp3", "codec_config": CodecConfig(bit_rate=192_000)},
{"format": "mp3", "codec_config": CodecConfig(bit_rate=8_000)},
{"format": "mp3", "codec_config": CodecConfig(qscale=9)},
{"format": "mp3", "codec_config": CodecConfig(qscale=1)},
]
)
######################################################################
# default
......
......@@ -27,14 +27,11 @@ import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
import matplotlib.pyplot as plt
######################################################################
#
from torchaudio.prototype.functional import (
sinc_impulse_response,
frequency_impulse_response,
)
import matplotlib.pyplot as plt
from torchaudio.prototype.functional import frequency_impulse_response, sinc_impulse_response
######################################################################
#
......@@ -75,7 +72,7 @@ import matplotlib.pyplot as plt
# :py:func:`~torchaudio.prototype.functional.sinc_impulse_response`.
#
cutoff = torch.linspace(0., 1., 9)
cutoff = torch.linspace(0.0, 1.0, 9)
irs = sinc_impulse_response(cutoff, window_size=513)
print("Cutoff shape:", cutoff.shape)
......@@ -87,6 +84,7 @@ print("Impulse response shape:", irs.shape)
# Let's visualize the resulting impulse responses.
#
def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape
half = window_size // 2
......@@ -99,7 +97,8 @@ def plot_sinc_ir(irs, cutoff):
ax.grid(True)
fig.suptitle(
"Impulse response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)")
"(Frequencies are relative to Nyquist frequency)"
)
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
plt.tight_layout()
......@@ -126,12 +125,12 @@ frs = torch.fft.rfft(irs, n=2048, dim=1).abs()
# Let's visualize the resulting frequency responses.
#
def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts
fig, axes = plt.subplots(
num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5))
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
......@@ -141,11 +140,12 @@ def plot_sinc_fr(frs, cutoff, band=False):
yticks=[1e-9, 1e-6, 1e-3, 1],
xticks=torch.linspace(0, num_fft, num_ticks),
xticklabels=[f"{i/(num_ticks - 1)}" for i in range(num_ticks)],
xlabel="Frequency"
xlabel="Frequency",
)
fig.suptitle(
"Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)")
"(Frequencies are relative to Nyquist frequency)"
)
plt.tight_layout()
......@@ -193,13 +193,11 @@ plot_sinc_fr(frs, cutoff)
# Band-pass filter can be obtained by subtracting low-pass filter for
# upper band from that of lower band.
cutoff = torch.linspace(0., 1, 11)
cutoff = torch.linspace(0.0, 1, 11)
c_low = cutoff[:-1]
c_high = cutoff[1:]
irs = (
sinc_impulse_response(c_low, window_size=513)
- sinc_impulse_response(c_high, window_size=513))
irs = sinc_impulse_response(c_low, window_size=513) - sinc_impulse_response(c_high, window_size=513)
frs = torch.fft.rfft(irs, n=2048, dim=1).abs()
######################################################################
......@@ -256,6 +254,7 @@ print("Impulse Response:", ir.shape)
######################################################################
#
def plot_ir(magnitudes, ir, num_fft=2048):
fr = torch.fft.rfft(ir, n=num_fft, dim=0).abs()
ir_size = ir.size(-1)
......@@ -268,17 +267,18 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[0].set(title="Impulse Response")
axes[0].set_xticks([i * half // 4 for i in range(-4, 5)])
t = torch.linspace(0, 1, fr.numel())
axes[1].plot(t, fr, label='Actual')
axes[2].semilogy(t, fr, label='Actual')
axes[1].plot(t, fr, label="Actual")
axes[2].semilogy(t, fr, label="Actual")
t = torch.linspace(0, 1, magnitudes.numel())
for i in range(1, 3):
axes[i].plot(t, magnitudes, label='Desired (input)', linewidth=1.1, linestyle='--')
axes[i].plot(t, magnitudes, label="Desired (input)", linewidth=1.1, linestyle="--")
axes[i].grid(True)
axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
axes[2].legend(loc="lower right")
fig.tight_layout()
######################################################################
#
......@@ -305,7 +305,7 @@ plot_ir(magnitudes, ir)
#
#
magnitudes = torch.linspace(0, 1, 64)**4.0
magnitudes = torch.linspace(0, 1, 64) ** 4.0
ir = frequency_impulse_response(magnitudes)
......@@ -316,7 +316,7 @@ plot_ir(magnitudes, ir)
######################################################################
#
magnitudes = torch.sin(torch.linspace(0, 10, 64))**4.0
magnitudes = torch.sin(torch.linspace(0, 10, 64)) ** 4.0
ir = frequency_impulse_response(magnitudes)
......
......@@ -450,6 +450,7 @@ plot_alignments(
)
plt.show()
################################################################################
#
......
......@@ -45,6 +45,8 @@ import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
import matplotlib.pyplot as plt
######################################################################
# In addition to ``torchaudio``, ``mir_eval`` is required to perform
# signal-to-distortion ratio (SDR) calculations. To install ``mir_eval``
......@@ -52,30 +54,9 @@ print(torchaudio.__version__)
#
from IPython.display import Audio
from mir_eval import separation
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.utils import download_asset
import matplotlib.pyplot as plt
try:
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from mir_eval import separation
except ModuleNotFoundError:
try:
import google.colab
print(
"""
To enable running this notebook in Google Colab, install nightly
torch and torchaudio builds by adding the following code block to the top
of the notebook before running it:
!pip3 uninstall -y torch torchvision torchaudio
!pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
!pip3 install mir_eval
"""
)
except ModuleNotFoundError:
pass
raise
######################################################################
# 3. Construct the pipeline
......@@ -130,11 +111,11 @@ from torchaudio.transforms import Fade
def separate_sources(
model,
mix,
segment=10.,
overlap=0.1,
device=None,
model,
mix,
segment=10.0,
overlap=0.1,
device=None,
):
"""
Apply model to a given mixture. Use fade, and add segments together in order to add model segment by segment.
......@@ -157,7 +138,7 @@ def separate_sources(
start = 0
end = chunk_len
overlap_frames = overlap * sample_rate
fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape='linear')
fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")
final = torch.zeros(batch, len(model.sources), channels, length, device=device)
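The Fade transform used here tapers each chunk before overlap-add. A standalone sketch of the same call on a hypothetical one-second, 2-channel chunk:

import torch
from torchaudio.transforms import Fade

chunk = torch.randn(2, 16_000)  # (channels, time)
fade = Fade(fade_in_len=0, fade_out_len=1_600, fade_shape="linear")
faded = fade(chunk)  # the last 1,600 samples are linearly attenuated to zero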
......@@ -265,12 +246,13 @@ stft = torchaudio.transforms.Spectrogram(
# scores.
#
def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor, source: str):
print("SDR score is:",
separation.bss_eval_sources(
original_source.detach().numpy(),
predicted_source.detach().numpy())[0].mean())
plot_spectrogram(stft(predicted_source)[0], f'Spectrogram {source}')
print(
"SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
)
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}")
return Audio(predicted_source, rate=sample_rate)
......@@ -285,19 +267,19 @@ bass_original = download_asset("tutorial-assets/hdemucs_bass_segment.wav")
vocals_original = download_asset("tutorial-assets/hdemucs_vocals_segment.wav")
other_original = download_asset("tutorial-assets/hdemucs_other_segment.wav")
drums_spec = audios["drums"][:, frame_start: frame_end].cpu()
drums_spec = audios["drums"][:, frame_start:frame_end].cpu()
drums, sample_rate = torchaudio.load(drums_original)
bass_spec = audios["bass"][:, frame_start: frame_end].cpu()
bass_spec = audios["bass"][:, frame_start:frame_end].cpu()
bass, sample_rate = torchaudio.load(bass_original)
vocals_spec = audios["vocals"][:, frame_start: frame_end].cpu()
vocals_spec = audios["vocals"][:, frame_start:frame_end].cpu()
vocals, sample_rate = torchaudio.load(vocals_original)
other_spec = audios["other"][:, frame_start: frame_end].cpu()
other_spec = audios["other"][:, frame_start:frame_end].cpu()
other, sample_rate = torchaudio.load(other_original)
mix_spec = mixture[:, frame_start: frame_end].cpu()
mix_spec = mixture[:, frame_start:frame_end].cpu()
######################################################################
......
......@@ -37,6 +37,10 @@ print(torch.__version__)
print(torchaudio.__version__)
import matplotlib.pyplot as plt
import mir_eval
from IPython.display import Audio
######################################################################
# 2. Preparation
# --------------
......@@ -59,10 +63,6 @@ print(torchaudio.__version__)
from pesq import pesq
from pystoi import stoi
import mir_eval
import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset
######################################################################
......
......@@ -45,30 +45,9 @@ import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
######################################################################
#
import IPython
import matplotlib.pyplot as plt
try:
from torchaudio.io import StreamReader
except ModuleNotFoundError:
try:
import google.colab
print(
"""
To enable running this notebook in Google Colab, install the requisite
third party libraries by running the following code block:
!add-apt-repository -y ppa:savoury1/ffmpeg4
!apt-get -qq install -y ffmpeg
"""
)
except ModuleNotFoundError:
pass
raise
from torchaudio.io import StreamReader
######################################################################
# 3. Construct the pipeline
......@@ -202,11 +181,11 @@ def _plot(feats, num_iter, unit=25):
fig, axes = plt.subplots(num_plots, 1)
t0 = 0
for i, ax in enumerate(axes):
feats_ = feats[i*unit:(i+1)*unit]
feats_ = feats[i * unit : (i + 1) * unit]
t1 = t0 + segment_length / sample_rate * len(feats_)
feats_ = torch.cat([f[2:-2] for f in feats_]) # remove boundary effect and overlap
ax.imshow(feats_.T, extent=[t0, t1, 0, 1], aspect="auto", origin="lower")
ax.tick_params(which='both', left=False, labelleft=False)
ax.tick_params(which="both", left=False, labelleft=False)
ax.set_xlim(t0, t0 + unit_dur)
t0 = t1
fig.suptitle("MelSpectrogram Feature")
......
......@@ -28,19 +28,18 @@ print(torchaudio.__version__)
#
try:
from torchaudio.prototype.functional import (
oscillator_bank,
adsr_envelope,
)
from torchaudio.prototype.functional import adsr_envelope, oscillator_bank
except ModuleNotFoundError:
print(
"Failed to import prototype DSP features. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build.")
"for instructions to install a nightly build."
)
raise
import math
import matplotlib.pyplot as plt
from IPython.display import Audio
......@@ -93,7 +92,7 @@ PI2 = 2 * torch.pi
# the rest of the tutorial.
#
F0 = 344. # fundamental frequency
F0 = 344.0 # fundamental frequency
DURATION = 1.1 # [seconds]
SAMPLE_RATE = 16_000 # [Hz]
......@@ -102,26 +101,19 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
######################################################################
#
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.3):
t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True)
axes[0].plot(t, freq)
axes[0].set(
title=f"Oscillator bank (bank size: {amp.size(-1)})",
ylabel="Frequency [Hz]",
ylim=[-0.03, None])
axes[0].set(title=f"Oscillator bank (bank size: {amp.size(-1)})", ylabel="Frequency [Hz]", ylim=[-0.03, None])
axes[1].plot(t, amp)
axes[1].set(
ylabel="Amplitude",
ylim=[-0.03 if torch.all(amp >= 0.0) else None, None])
axes[1].set(ylabel="Amplitude", ylim=[-0.03 if torch.all(amp >= 0.0) else None, None])
axes[2].plot(t, waveform)
axes[2].set(ylabel="Waveform")
axes[3].specgram(waveform, Fs=sample_rate)
axes[3].set(
ylabel="Spectrogram",
xlabel="Time [s]",
xlim=[-0.01, t[-1] + 0.01])
axes[3].set(ylabel="Spectrogram", xlabel="Time [s]", xlim=[-0.01, t[-1] + 0.01])
for i in range(4):
axes[i].grid(True)
......@@ -147,7 +139,7 @@ amp = torch.ones((NUM_FRAMES, 1))
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Combining multiple sine waves
......@@ -166,7 +158,7 @@ amp = torch.ones((NUM_FRAMES, 3)) / 3
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1/F0, 3/F0))
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
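For reference, oscillator_bank sums one sinusoid per column of freq/amp, integrating the instantaneous frequency into phase; schematically (exact phase convention and reduction assumed):

y[n] = \sum_{j} a_j[n] \, \sin\!\Bigl(\tfrac{2\pi}{\mathrm{SR}} \sum_{m \le n} f_j[m]\Bigr).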
######################################################################
......@@ -279,7 +271,8 @@ amp = torch.stack(
adsr_envelope(unit, attack=0.01, hold=0.125, decay=0.12, sustain=0.05, release=0),
adsr_envelope(unit, attack=0.01, hold=0.25, decay=0.08, sustain=0, release=0),
),
dim=-1)
dim=-1,
)
amp = amp.repeat(repeat, 1) / 2
bass = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
......@@ -316,7 +309,7 @@ show(freq, amp, doremi, SAMPLE_RATE)
# ~~~~~
#
env = adsr_envelope(NUM_FRAMES * 6, attack=0.98, decay=0., sustain=1, release=0.02)
env = adsr_envelope(NUM_FRAMES * 6, attack=0.98, decay=0.0, sustain=1, release=0.02)
tones = [
484.90, # B4
......
......@@ -78,12 +78,11 @@ print(torchaudio.__version__)
#
try:
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE
from torchaudio.prototype.pipelines import SQUIM_SUBJECTIVE
from pesq import pesq
from pystoi import stoi
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError:
import google.colab
import google.colab # noqa: F401
print(
"""
......@@ -98,14 +97,15 @@ except ImportError:
)
import matplotlib.pyplot as plt
######################################################################
#
#
import torchaudio.functional as F
from torchaudio.utils import download_asset
from IPython.display import Audio
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
def si_snr(estimate, reference, epsilon=1e-8):
......