Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
114461cc
"docs/vscode:/vscode.git/clone" did not exist on "b036725166ffc2803c0d4156f7abd859b45e92b3"
Unverified
Commit
114461cc
authored
Jan 29, 2021
by
jieruan
Committed by
GitHub
Jan 29, 2021
Browse files
Expose stft arguments to MelSpectrogram (#1211)
parent
af1e457e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
2 deletions
+29
-2
test/torchaudio_unittest/transforms_test.py
test/torchaudio_unittest/transforms_test.py
+16
-0
torchaudio/transforms.py
torchaudio/transforms.py
+13
-2
No files found.
test/torchaudio_unittest/transforms_test.py
View file @
114461cc
...
...
@@ -219,3 +219,19 @@ class Tester(common_utils.TorchaudioTestCase):
computed
=
transform
(
specgram
)
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
self
.
assertEqual
(
computed
,
expected
,
atol
=
1e-6
,
rtol
=
1e-8
)
class
SmokeTest
(
common_utils
.
TorchaudioTestCase
):
def
test_spectrogram
(
self
):
specgram
=
transforms
.
Spectrogram
(
center
=
False
,
pad_mode
=
"reflect"
,
onesided
=
False
)
self
.
assertEqual
(
specgram
.
center
,
False
)
self
.
assertEqual
(
specgram
.
pad_mode
,
"reflect"
)
self
.
assertEqual
(
specgram
.
onesided
,
False
)
def
test_melspectrogram
(
self
):
melspecgram
=
transforms
.
MelSpectrogram
(
center
=
True
,
pad_mode
=
"reflect"
,
onesided
=
False
)
specgram
=
melspecgram
.
spectrogram
self
.
assertEqual
(
specgram
.
center
,
True
)
self
.
assertEqual
(
specgram
.
pad_mode
,
"reflect"
)
self
.
assertEqual
(
specgram
.
onesided
,
False
)
torchaudio/transforms.py
View file @
114461cc
...
...
@@ -411,6 +411,13 @@ class MelSpectrogram(torch.nn.Module):
window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
center (bool, optional): whether to pad :attr:`waveform` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
pad_mode (string, optional): controls the padding method used when
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
...
...
@@ -430,7 +437,10 @@ class MelSpectrogram(torch.nn.Module):
window_fn
:
Callable
[...,
Tensor
]
=
torch
.
hann_window
,
power
:
Optional
[
float
]
=
2.
,
normalized
:
bool
=
False
,
wkwargs
:
Optional
[
dict
]
=
None
)
->
None
:
wkwargs
:
Optional
[
dict
]
=
None
,
center
:
bool
=
True
,
pad_mode
:
str
=
"reflect"
,
onesided
:
bool
=
True
)
->
None
:
super
(
MelSpectrogram
,
self
).
__init__
()
self
.
sample_rate
=
sample_rate
self
.
n_fft
=
n_fft
...
...
@@ -445,7 +455,8 @@ class MelSpectrogram(torch.nn.Module):
self
.
spectrogram
=
Spectrogram
(
n_fft
=
self
.
n_fft
,
win_length
=
self
.
win_length
,
hop_length
=
self
.
hop_length
,
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
self
.
power
,
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
)
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
,
center
=
center
,
pad_mode
=
pad_mode
,
onesided
=
onesided
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment