Unverified Commit 172260f9 authored by moto's avatar moto Committed by GitHub
Browse files

Update TimeStretch doc and tutorial (#3694)

parent 65df10bb
...@@ -25,6 +25,7 @@ print(torchaudio.__version__) ...@@ -25,6 +25,7 @@ print(torchaudio.__version__)
import librosa import librosa
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset from torchaudio.utils import download_asset
###################################################################### ######################################################################
...@@ -69,11 +70,6 @@ def get_spectrogram( ...@@ -69,11 +70,6 @@ def get_spectrogram(
return spectrogram(waveform) return spectrogram(waveform)
def plot_spec(ax, spec, title, ylabel="freq_bin"):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
###################################################################### ######################################################################
# SpecAugment # SpecAugment
# ----------- # -----------
...@@ -98,11 +94,15 @@ stretch = T.TimeStretch() ...@@ -98,11 +94,15 @@ stretch = T.TimeStretch()
spec_12 = stretch(spec, overriding_rate=1.2) spec_12 = stretch(spec, overriding_rate=1.2)
spec_09 = stretch(spec, overriding_rate=0.9) spec_09 = stretch(spec, overriding_rate=0.9)
######################################################################
#
######################################################################
# Visualization
# ~~~~~~~~~~~~~
def plot(): def plot():
def plot_spec(ax, spec, title):
ax.set_title(title)
ax.imshow(librosa.amplitude_to_db(spec), origin="lower", aspect="auto")
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2") plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original") plot_spec(axes[1], torch.abs(spec[0]), title="Original")
...@@ -112,6 +112,30 @@ def plot(): ...@@ -112,6 +112,30 @@ def plot():
plot() plot()
######################################################################
# Audio Samples
# ~~~~~~~~~~~~~
def preview(spec, rate=16000):
ispec = T.InverseSpectrogram()
waveform = ispec(spec)
return Audio(waveform[0].numpy().T, rate=rate)
preview(spec)
######################################################################
#
preview(spec_12)
######################################################################
#
preview(spec_09)
###################################################################### ######################################################################
# Time and Frequency Masking # Time and Frequency Masking
# -------------------------- # --------------------------
...@@ -131,6 +155,10 @@ freq_masked = freq_masking(spec) ...@@ -131,6 +155,10 @@ freq_masked = freq_masking(spec)
def plot(): def plot():
def plot_spec(ax, spec, title):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original") plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis") plot_spec(axes[1], time_masked[0], title="Masked along time axis")
......
...@@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module): ...@@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module):
Proposed in *SpecAugment* :cite:`specaugment`. Proposed in *SpecAugment* :cite:`specaugment`.
Args: Args:
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) hop_length (int or None, optional): Length of hop between STFT windows.
(Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``)
n_freq (int, optional): number of filter banks from stft. (Default: ``201``) n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
fixed_rate (float or None, optional): rate to speed up or slow down by. fixed_rate (float or None, optional): rate to speed up or slow down by.
If None is provided, rate must be passed to the forward method. (Default: ``None``) If None is provided, rate must be passed to the forward method. (Default: ``None``)
.. note::
The expected input is raw, complex-valued spectrogram.
Example Example
>>> spectrogram = torchaudio.transforms.Spectrogram() >>> spectrogram = torchaudio.transforms.Spectrogram(power=None)
>>> stretch = torchaudio.transforms.TimeStretch() >>> stretch = torchaudio.transforms.TimeStretch()
>>> >>>
>>> original = spectrogram(waveform) >>> original = spectrogram(waveform)
>>> streched_1_2 = stretch(original, 1.2) >>> stretched_1_2 = stretch(original, 1.2)
>>> streched_0_9 = stretch(original, 0.9) >>> stretched_0_9 = stretch(original, 0.9)
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png
:width: 600
:alt: Spectrogram streched by 1.2
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png
:width: 600 :width: 600
:alt: The original spectrogram :alt: The visualization of stretched spectrograms.
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png
:width: 600
:alt: Spectrogram streched by 0.9
""" """
__constants__ = ["fixed_rate"] __constants__ = ["fixed_rate"]
...@@ -1067,8 +1063,8 @@ class TimeStretch(torch.nn.Module): ...@@ -1067,8 +1063,8 @@ class TimeStretch(torch.nn.Module):
Returns: Returns:
Tensor: Tensor:
Stretched spectrogram. The resulting tensor is of the same dtype as the input Stretched spectrogram. The resulting tensor is of the corresponding complex dtype
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``. as the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``.
""" """
if overriding_rate is None: if overriding_rate is None:
if self.fixed_rate is None: if self.fixed_rate is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment