Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
703614f2
Unverified
Commit
703614f2
authored
Feb 02, 2021
by
jieruan
Committed by
GitHub
Feb 02, 2021
Browse files
Expose normalization method to Mel transforms (#1212)
parent
a4c095a3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
7 deletions
+18
-7
torchaudio/transforms.py
torchaudio/transforms.py
+18
-7
No files found.
torchaudio/transforms.py
View file @
703614f2
...
@@ -248,6 +248,8 @@ class MelScale(torch.nn.Module):
...
@@ -248,6 +248,8 @@ class MelScale(torch.nn.Module):
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
n_stft (int, optional): Number of bins in STFT. Calculated from first input
n_stft (int, optional): Number of bins in STFT. Calculated from first input
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
"""
"""
__constants__
=
[
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
]
__constants__
=
[
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
]
...
@@ -256,17 +258,19 @@ class MelScale(torch.nn.Module):
...
@@ -256,17 +258,19 @@ class MelScale(torch.nn.Module):
sample_rate
:
int
=
16000
,
sample_rate
:
int
=
16000
,
f_min
:
float
=
0.
,
f_min
:
float
=
0.
,
f_max
:
Optional
[
float
]
=
None
,
f_max
:
Optional
[
float
]
=
None
,
n_stft
:
Optional
[
int
]
=
None
)
->
None
:
n_stft
:
Optional
[
int
]
=
None
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
MelScale
,
self
).
__init__
()
super
(
MelScale
,
self
).
__init__
()
self
.
n_mels
=
n_mels
self
.
n_mels
=
n_mels
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
f_max
=
f_max
if
f_max
is
not
None
else
float
(
sample_rate
//
2
)
self
.
f_max
=
f_max
if
f_max
is
not
None
else
float
(
sample_rate
//
2
)
self
.
f_min
=
f_min
self
.
f_min
=
f_min
self
.
norm
=
norm
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
torch
.
empty
(
0
)
if
n_stft
is
None
else
F
.
create_fb_matrix
(
fb
=
torch
.
empty
(
0
)
if
n_stft
is
None
else
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
)
self
.
register_buffer
(
'fb'
,
fb
)
self
.
register_buffer
(
'fb'
,
fb
)
def
forward
(
self
,
specgram
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
specgram
:
Tensor
)
->
Tensor
:
...
@@ -283,7 +287,8 @@ class MelScale(torch.nn.Module):
...
@@ -283,7 +287,8 @@ class MelScale(torch.nn.Module):
specgram
=
specgram
.
reshape
(
-
1
,
shape
[
-
2
],
shape
[
-
1
])
specgram
=
specgram
.
reshape
(
-
1
,
shape
[
-
2
],
shape
[
-
1
])
if
self
.
fb
.
numel
()
==
0
:
if
self
.
fb
.
numel
()
==
0
:
tmp_fb
=
F
.
create_fb_matrix
(
specgram
.
size
(
1
),
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
tmp_fb
=
F
.
create_fb_matrix
(
specgram
.
size
(
1
),
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
)
# Attributes cannot be reassigned outside __init__ so workaround
# Attributes cannot be reassigned outside __init__ so workaround
self
.
fb
.
resize_
(
tmp_fb
.
size
())
self
.
fb
.
resize_
(
tmp_fb
.
size
())
self
.
fb
.
copy_
(
tmp_fb
)
self
.
fb
.
copy_
(
tmp_fb
)
...
@@ -315,6 +320,8 @@ class InverseMelScale(torch.nn.Module):
...
@@ -315,6 +320,8 @@ class InverseMelScale(torch.nn.Module):
tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
"""
"""
__constants__
=
[
'n_stft'
,
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
,
'max_iter'
,
'tolerance_loss'
,
__constants__
=
[
'n_stft'
,
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
,
'max_iter'
,
'tolerance_loss'
,
'tolerance_change'
,
'sgdargs'
]
'tolerance_change'
,
'sgdargs'
]
...
@@ -328,7 +335,8 @@ class InverseMelScale(torch.nn.Module):
...
@@ -328,7 +335,8 @@ class InverseMelScale(torch.nn.Module):
max_iter
:
int
=
100000
,
max_iter
:
int
=
100000
,
tolerance_loss
:
float
=
1e-5
,
tolerance_loss
:
float
=
1e-5
,
tolerance_change
:
float
=
1e-8
,
tolerance_change
:
float
=
1e-8
,
sgdargs
:
Optional
[
dict
]
=
None
)
->
None
:
sgdargs
:
Optional
[
dict
]
=
None
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
InverseMelScale
,
self
).
__init__
()
super
(
InverseMelScale
,
self
).
__init__
()
self
.
n_mels
=
n_mels
self
.
n_mels
=
n_mels
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
...
@@ -341,7 +349,7 @@ class InverseMelScale(torch.nn.Module):
...
@@ -341,7 +349,7 @@ class InverseMelScale(torch.nn.Module):
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
norm
)
self
.
register_buffer
(
'fb'
,
fb
)
self
.
register_buffer
(
'fb'
,
fb
)
def
forward
(
self
,
melspec
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
melspec
:
Tensor
)
->
Tensor
:
...
@@ -418,6 +426,8 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -418,6 +426,8 @@ class MelSpectrogram(torch.nn.Module):
:attr:`center` is ``True``. Default: ``"reflect"``
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
avoid redundancy. Default: ``True``
norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
Example
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
...
@@ -440,7 +450,8 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -440,7 +450,8 @@ class MelSpectrogram(torch.nn.Module):
wkwargs
:
Optional
[
dict
]
=
None
,
wkwargs
:
Optional
[
dict
]
=
None
,
center
:
bool
=
True
,
center
:
bool
=
True
,
pad_mode
:
str
=
"reflect"
,
pad_mode
:
str
=
"reflect"
,
onesided
:
bool
=
True
)
->
None
:
onesided
:
bool
=
True
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
MelSpectrogram
,
self
).
__init__
()
super
(
MelSpectrogram
,
self
).
__init__
()
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
n_fft
=
n_fft
self
.
n_fft
=
n_fft
...
@@ -457,7 +468,7 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -457,7 +468,7 @@ class MelSpectrogram(torch.nn.Module):
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
self
.
power
,
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
self
.
power
,
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
,
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
,
center
=
center
,
pad_mode
=
pad_mode
,
onesided
=
onesided
)
center
=
center
,
pad_mode
=
pad_mode
,
onesided
=
onesided
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
,
norm
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment