Unverified Commit 172260f9 authored by moto's avatar moto Committed by GitHub
Browse files

Update TimeStretch doc and tutorial (#3694)

parent 65df10bb
......@@ -25,6 +25,7 @@ print(torchaudio.__version__)
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset
######################################################################
......@@ -69,11 +70,6 @@ def get_spectrogram(
return spectrogram(waveform)
def plot_spec(ax, spec, title, ylabel="freq_bin"):
    """Draw a power spectrogram on ``ax`` in decibel scale.

    Args:
        ax: Axes-like object; ``imshow`` and ``set_title`` are called on it.
        spec: Spectrogram passed to ``librosa.power_to_db`` before drawing.
        title: Text placed above the image.
        ylabel: Accepted for call-site compatibility; this function does not
            apply it to the axes.
    """
    decibels = librosa.power_to_db(spec)
    # origin="lower" puts the first row of ``spec`` at the bottom of the image.
    ax.imshow(decibels, origin="lower", aspect="auto")
    ax.set_title(title)
######################################################################
# SpecAugment
# -----------
......@@ -98,11 +94,15 @@ stretch = T.TimeStretch()
# ``overriding_rate`` supplies the stretch rate per call. The output has
# ``ceil(num_frame / rate)`` frames, so 1.2 yields fewer frames than the
# input (sped up) and 0.9 yields more frames (slowed down).
spec_12 = stretch(spec, overriding_rate=1.2)
spec_09 = stretch(spec, overriding_rate=0.9)
######################################################################
#
######################################################################
# Visualization
# ~~~~~~~~~~~~~
def plot():
def plot_spec(ax, spec, title):
ax.set_title(title)
ax.imshow(librosa.amplitude_to_db(spec), origin="lower", aspect="auto")
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
......@@ -112,6 +112,30 @@ def plot():
plot()
######################################################################
# Audio Samples
# ~~~~~~~~~~~~~
def preview(spec, rate=16000):
    """Turn a spectrogram back into audio and wrap it in a notebook player.

    The spectrogram is inverted with a ``T.InverseSpectrogram`` built from its
    default arguments, so ``spec`` is presumably expected to have been
    produced with matching (default) STFT settings -- TODO confirm.

    Args:
        spec: Spectrogram to invert.
        rate: Sample rate handed to the ``Audio`` widget. (Default: ``16000``)

    Returns:
        ``Audio`` object for the first entry along the leading dimension of
        the reconstructed waveform (presumably the first channel).
    """
    reconstructed = T.InverseSpectrogram()(spec)
    samples = reconstructed[0].numpy().T
    return Audio(samples, rate=rate)
preview(spec)
######################################################################
#
preview(spec_12)
######################################################################
#
preview(spec_09)
######################################################################
# Time and Frequency Masking
# --------------------------
......@@ -131,6 +155,10 @@ freq_masked = freq_masking(spec)
def plot():
def plot_spec(ax, spec, title):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
......
......@@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module):
Proposed in *SpecAugment* :cite:`specaugment`.
Args:
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
hop_length (int or None, optional): Length of hop between STFT windows.
(Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``)
n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
fixed_rate (float or None, optional): rate to speed up or slow down by.
If None is provided, rate must be passed to the forward method. (Default: ``None``)
.. note::
The expected input is raw, complex-valued spectrogram.
Example
>>> spectrogram = torchaudio.transforms.Spectrogram()
>>> spectrogram = torchaudio.transforms.Spectrogram(power=None)
>>> stretch = torchaudio.transforms.TimeStretch()
>>>
>>> original = spectrogram(waveform)
>>> streched_1_2 = stretch(original, 1.2)
>>> streched_0_9 = stretch(original, 0.9)
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png
:width: 600
:alt: Spectrogram streched by 1.2
>>> stretched_1_2 = stretch(original, 1.2)
>>> stretched_0_9 = stretch(original, 0.9)
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png
:width: 600
:alt: The original spectrogram
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png
:width: 600
:alt: Spectrogram streched by 0.9
:alt: The visualization of stretched spectrograms.
"""
__constants__ = ["fixed_rate"]
......@@ -1067,8 +1063,8 @@ class TimeStretch(torch.nn.Module):
Returns:
Tensor:
Stretched spectrogram. The resulting tensor is of the same dtype as the input
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
Stretched spectrogram. The resulting tensor has the complex dtype corresponding to the
input spectrogram's dtype, and the number of frames is changed to ``ceil(num_frame / rate)``.
"""
if overriding_rate is None:
if self.fixed_rate is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment