Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
703614f2
Unverified
Commit
703614f2
authored
Feb 02, 2021
by
jieruan
Committed by
GitHub
Feb 02, 2021
Browse files
Expose normalization method to Mel transforms (#1212)
parent
a4c095a3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
7 deletions
+18
-7
torchaudio/transforms.py
torchaudio/transforms.py
+18
-7
No files found.
torchaudio/transforms.py
View file @
703614f2
...
...
@@ -248,6 +248,8 @@ class MelScale(torch.nn.Module):
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
n_stft (int, optional): Number of bins in STFT. Calculated from first input
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
"""
__constants__
=
[
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
]
...
...
@@ -256,17 +258,19 @@ class MelScale(torch.nn.Module):
sample_rate
:
int
=
16000
,
f_min
:
float
=
0.
,
f_max
:
Optional
[
float
]
=
None
,
n_stft
:
Optional
[
int
]
=
None
)
->
None
:
n_stft
:
Optional
[
int
]
=
None
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
MelScale
,
self
).
__init__
()
self
.
n_mels
=
n_mels
self
.
sample_rate
=
sample_rate
self
.
f_max
=
f_max
if
f_max
is
not
None
else
float
(
sample_rate
//
2
)
self
.
f_min
=
f_min
self
.
norm
=
norm
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
torch
.
empty
(
0
)
if
n_stft
is
None
else
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
)
self
.
register_buffer
(
'fb'
,
fb
)
def
forward
(
self
,
specgram
:
Tensor
)
->
Tensor
:
...
...
@@ -283,7 +287,8 @@ class MelScale(torch.nn.Module):
specgram
=
specgram
.
reshape
(
-
1
,
shape
[
-
2
],
shape
[
-
1
])
if
self
.
fb
.
numel
()
==
0
:
tmp_fb
=
F
.
create_fb_matrix
(
specgram
.
size
(
1
),
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
tmp_fb
=
F
.
create_fb_matrix
(
specgram
.
size
(
1
),
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
)
# Attributes cannot be reassigned outside __init__ so workaround
self
.
fb
.
resize_
(
tmp_fb
.
size
())
self
.
fb
.
copy_
(
tmp_fb
)
...
...
@@ -315,6 +320,8 @@ class InverseMelScale(torch.nn.Module):
tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
"""
__constants__
=
[
'n_stft'
,
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
,
'max_iter'
,
'tolerance_loss'
,
'tolerance_change'
,
'sgdargs'
]
...
...
@@ -328,7 +335,8 @@ class InverseMelScale(torch.nn.Module):
max_iter
:
int
=
100000
,
tolerance_loss
:
float
=
1e-5
,
tolerance_change
:
float
=
1e-8
,
sgdargs
:
Optional
[
dict
]
=
None
)
->
None
:
sgdargs
:
Optional
[
dict
]
=
None
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
InverseMelScale
,
self
).
__init__
()
self
.
n_mels
=
n_mels
self
.
sample_rate
=
sample_rate
...
...
@@ -341,7 +349,7 @@ class InverseMelScale(torch.nn.Module):
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
norm
)
self
.
register_buffer
(
'fb'
,
fb
)
def
forward
(
self
,
melspec
:
Tensor
)
->
Tensor
:
...
...
@@ -418,6 +426,8 @@ class MelSpectrogram(torch.nn.Module):
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
...
...
@@ -440,7 +450,8 @@ class MelSpectrogram(torch.nn.Module):
wkwargs
:
Optional
[
dict
]
=
None
,
center
:
bool
=
True
,
pad_mode
:
str
=
"reflect"
,
onesided
:
bool
=
True
)
->
None
:
onesided
:
bool
=
True
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
MelSpectrogram
,
self
).
__init__
()
self
.
sample_rate
=
sample_rate
self
.
n_fft
=
n_fft
...
...
@@ -457,7 +468,7 @@ class MelSpectrogram(torch.nn.Module):
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
self
.
power
,
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
,
center
=
center
,
pad_mode
=
pad_mode
,
onesided
=
onesided
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
,
norm
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment