Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
172260f9
Unverified
Commit
172260f9
authored
Nov 09, 2023
by
moto
Committed by
GitHub
Nov 09, 2023
Browse files
Update TimeStretch doc and tutorial (#3694)
parent
65df10bb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
25 deletions
+49
-25
examples/tutorials/audio_feature_augmentation_tutorial.py
examples/tutorials/audio_feature_augmentation_tutorial.py
+36
-8
src/torchaudio/transforms/_transforms.py
src/torchaudio/transforms/_transforms.py
+13
-17
No files found.
examples/tutorials/audio_feature_augmentation_tutorial.py
View file @
172260f9
...
...
@@ -25,6 +25,7 @@ print(torchaudio.__version__)
import
librosa
import
matplotlib.pyplot
as
plt
from
IPython.display
import
Audio
from
torchaudio.utils
import
download_asset
######################################################################
...
...
@@ -69,11 +70,6 @@ def get_spectrogram(
return
spectrogram
(
waveform
)
def
plot_spec
(
ax
,
spec
,
title
,
ylabel
=
"freq_bin"
):
ax
.
set_title
(
title
)
ax
.
imshow
(
librosa
.
power_to_db
(
spec
),
origin
=
"lower"
,
aspect
=
"auto"
)
######################################################################
# SpecAugment
# -----------
...
...
@@ -98,11 +94,15 @@ stretch = T.TimeStretch()
spec_12
=
stretch
(
spec
,
overriding_rate
=
1.2
)
spec_09
=
stretch
(
spec
,
overriding_rate
=
0.9
)
######################################################################
#
######################################################################
# Visualization
# ~~~~~~~~~~~~~
def
plot
():
def
plot_spec
(
ax
,
spec
,
title
):
ax
.
set_title
(
title
)
ax
.
imshow
(
librosa
.
amplitude_to_db
(
spec
),
origin
=
"lower"
,
aspect
=
"auto"
)
fig
,
axes
=
plt
.
subplots
(
3
,
1
,
sharex
=
True
,
sharey
=
True
)
plot_spec
(
axes
[
0
],
torch
.
abs
(
spec_12
[
0
]),
title
=
"Stretched x1.2"
)
plot_spec
(
axes
[
1
],
torch
.
abs
(
spec
[
0
]),
title
=
"Original"
)
...
...
@@ -112,6 +112,30 @@ def plot():
plot
()
######################################################################
# Audio Samples
# ~~~~~~~~~~~~~
def
preview
(
spec
,
rate
=
16000
):
ispec
=
T
.
InverseSpectrogram
()
waveform
=
ispec
(
spec
)
return
Audio
(
waveform
[
0
].
numpy
().
T
,
rate
=
rate
)
preview
(
spec
)
######################################################################
#
preview
(
spec_12
)
######################################################################
#
preview
(
spec_09
)
######################################################################
# Time and Frequency Masking
# --------------------------
...
...
@@ -131,6 +155,10 @@ freq_masked = freq_masking(spec)
def
plot
():
def
plot_spec
(
ax
,
spec
,
title
):
ax
.
set_title
(
title
)
ax
.
imshow
(
librosa
.
power_to_db
(
spec
),
origin
=
"lower"
,
aspect
=
"auto"
)
fig
,
axes
=
plt
.
subplots
(
3
,
1
,
sharex
=
True
,
sharey
=
True
)
plot_spec
(
axes
[
0
],
spec
[
0
],
title
=
"Original"
)
plot_spec
(
axes
[
1
],
time_masked
[
0
],
title
=
"Masked along time axis"
)
...
...
src/torchaudio/transforms/_transforms.py
View file @
172260f9
...
...
@@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module):
Proposed in *SpecAugment* :cite:`specaugment`.
Args:
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
hop_length (int or None, optional): Length of hop between STFT windows.
(Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``)
n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
fixed_rate (float or None, optional): rate to speed up or slow down by.
If None is provided, rate must be passed to the forward method. (Default: ``None``)
.. note::
The expected input is raw, complex-valued spectrogram.
Example
>>> spectrogram = torchaudio.transforms.Spectrogram()
>>> spectrogram = torchaudio.transforms.Spectrogram(
power=None
)
>>> stretch = torchaudio.transforms.TimeStretch()
>>>
>>> original = spectrogram(waveform)
>>> streched_1_2 = stretch(original, 1.2)
>>> streched_0_9 = stretch(original, 0.9)
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png
:width: 600
:alt: Spectrogram streched by 1.2
>>> stretched_1_2 = stretch(original, 1.2)
>>> stretched_0_9 = stretch(original, 0.9)
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch
_2
.png
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png
:width: 600
:alt: The original spectrogram
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png
:width: 600
:alt: Spectrogram streched by 0.9
:alt: The visualization of stretched spectrograms.
"""
__constants__
=
[
"fixed_rate"
]
...
...
@@ -1067,8 +1063,8 @@ class TimeStretch(torch.nn.Module):
Returns:
Tensor:
Stretched spectrogram. The resulting tensor is of the
same dtype as the input
spectrogram,
but
the number of frames is changed to ``ceil(num_frame / rate)``.
Stretched spectrogram. The resulting tensor is of the
corresponding complex dtype
as the input
spectrogram,
and
the number of frames is changed to ``ceil(num_frame / rate)``.
"""
if
overriding_rate
is
None
:
if
self
.
fixed_rate
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment