Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
703614f2
"docs/zh_cn/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "733e6ff84eecaf5f987ca368bee12274d4a8eb59"
Unverified
Commit
703614f2
authored
Feb 02, 2021
by
jieruan
Committed by
GitHub
Feb 02, 2021
Browse files
Expose normalization method to Mel transforms (#1212)
parent
a4c095a3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
7 deletions
+18
-7
torchaudio/transforms.py
torchaudio/transforms.py
+18
-7
No files found.
torchaudio/transforms.py
View file @
703614f2
...
@@ -248,6 +248,8 @@ class MelScale(torch.nn.Module):
...
@@ -248,6 +248,8 @@ class MelScale(torch.nn.Module):
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
n_stft (int, optional): Number of bins in STFT. Calculated from first input
n_stft (int, optional): Number of bins in STFT. Calculated from first input
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
"""
"""
__constants__
=
[
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
]
__constants__
=
[
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
]
...
@@ -256,17 +258,19 @@ class MelScale(torch.nn.Module):
...
@@ -256,17 +258,19 @@ class MelScale(torch.nn.Module):
sample_rate
:
int
=
16000
,
sample_rate
:
int
=
16000
,
f_min
:
float
=
0.
,
f_min
:
float
=
0.
,
f_max
:
Optional
[
float
]
=
None
,
f_max
:
Optional
[
float
]
=
None
,
n_stft
:
Optional
[
int
]
=
None
)
->
None
:
n_stft
:
Optional
[
int
]
=
None
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
MelScale
,
self
).
__init__
()
super
(
MelScale
,
self
).
__init__
()
self
.
n_mels
=
n_mels
self
.
n_mels
=
n_mels
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
f_max
=
f_max
if
f_max
is
not
None
else
float
(
sample_rate
//
2
)
self
.
f_max
=
f_max
if
f_max
is
not
None
else
float
(
sample_rate
//
2
)
self
.
f_min
=
f_min
self
.
f_min
=
f_min
self
.
norm
=
norm
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
torch
.
empty
(
0
)
if
n_stft
is
None
else
F
.
create_fb_matrix
(
fb
=
torch
.
empty
(
0
)
if
n_stft
is
None
else
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
)
self
.
register_buffer
(
'fb'
,
fb
)
self
.
register_buffer
(
'fb'
,
fb
)
def
forward
(
self
,
specgram
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
specgram
:
Tensor
)
->
Tensor
:
...
@@ -283,7 +287,8 @@ class MelScale(torch.nn.Module):
...
@@ -283,7 +287,8 @@ class MelScale(torch.nn.Module):
specgram
=
specgram
.
reshape
(
-
1
,
shape
[
-
2
],
shape
[
-
1
])
specgram
=
specgram
.
reshape
(
-
1
,
shape
[
-
2
],
shape
[
-
1
])
if
self
.
fb
.
numel
()
==
0
:
if
self
.
fb
.
numel
()
==
0
:
tmp_fb
=
F
.
create_fb_matrix
(
specgram
.
size
(
1
),
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
tmp_fb
=
F
.
create_fb_matrix
(
specgram
.
size
(
1
),
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
)
# Attributes cannot be reassigned outside __init__ so workaround
# Attributes cannot be reassigned outside __init__ so workaround
self
.
fb
.
resize_
(
tmp_fb
.
size
())
self
.
fb
.
resize_
(
tmp_fb
.
size
())
self
.
fb
.
copy_
(
tmp_fb
)
self
.
fb
.
copy_
(
tmp_fb
)
...
@@ -315,6 +320,8 @@ class InverseMelScale(torch.nn.Module):
...
@@ -315,6 +320,8 @@ class InverseMelScale(torch.nn.Module):
tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
"""
"""
__constants__
=
[
'n_stft'
,
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
,
'max_iter'
,
'tolerance_loss'
,
__constants__
=
[
'n_stft'
,
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
,
'max_iter'
,
'tolerance_loss'
,
'tolerance_change'
,
'sgdargs'
]
'tolerance_change'
,
'sgdargs'
]
...
@@ -328,7 +335,8 @@ class InverseMelScale(torch.nn.Module):
...
@@ -328,7 +335,8 @@ class InverseMelScale(torch.nn.Module):
max_iter
:
int
=
100000
,
max_iter
:
int
=
100000
,
tolerance_loss
:
float
=
1e-5
,
tolerance_loss
:
float
=
1e-5
,
tolerance_change
:
float
=
1e-8
,
tolerance_change
:
float
=
1e-8
,
sgdargs
:
Optional
[
dict
]
=
None
)
->
None
:
sgdargs
:
Optional
[
dict
]
=
None
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
InverseMelScale
,
self
).
__init__
()
super
(
InverseMelScale
,
self
).
__init__
()
self
.
n_mels
=
n_mels
self
.
n_mels
=
n_mels
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
...
@@ -341,7 +349,7 @@ class InverseMelScale(torch.nn.Module):
...
@@ -341,7 +349,7 @@ class InverseMelScale(torch.nn.Module):
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
norm
)
self
.
register_buffer
(
'fb'
,
fb
)
self
.
register_buffer
(
'fb'
,
fb
)
def
forward
(
self
,
melspec
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
melspec
:
Tensor
)
->
Tensor
:
...
@@ -418,6 +426,8 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -418,6 +426,8 @@ class MelSpectrogram(torch.nn.Module):
:attr:`center` is ``True``. Default: ``"reflect"``
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
avoid redundancy. Default: ``True``
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
Example
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
...
@@ -440,7 +450,8 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -440,7 +450,8 @@ class MelSpectrogram(torch.nn.Module):
wkwargs
:
Optional
[
dict
]
=
None
,
wkwargs
:
Optional
[
dict
]
=
None
,
center
:
bool
=
True
,
center
:
bool
=
True
,
pad_mode
:
str
=
"reflect"
,
pad_mode
:
str
=
"reflect"
,
onesided
:
bool
=
True
)
->
None
:
onesided
:
bool
=
True
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
super
(
MelSpectrogram
,
self
).
__init__
()
super
(
MelSpectrogram
,
self
).
__init__
()
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
n_fft
=
n_fft
self
.
n_fft
=
n_fft
...
@@ -457,7 +468,7 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -457,7 +468,7 @@ class MelSpectrogram(torch.nn.Module):
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
self
.
power
,
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
self
.
power
,
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
,
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
,
center
=
center
,
pad_mode
=
pad_mode
,
onesided
=
onesided
)
center
=
center
,
pad_mode
=
pad_mode
,
onesided
=
onesided
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
,
norm
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment