Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
5521f6c7
Unverified
Commit
5521f6c7
authored
Mar 02, 2021
by
Vincent QB
Committed by
GitHub
Mar 02, 2021
Browse files
add mel_scale option for slaney/htk (#593)
parent
ecfed4d9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
129 additions
and
24 deletions
+129
-24
test/torchaudio_unittest/functional/librosa_compatibility_test.py
...chaudio_unittest/functional/librosa_compatibility_test.py
+19
-3
test/torchaudio_unittest/librosa_compatibility_test.py
test/torchaudio_unittest/librosa_compatibility_test.py
+10
-8
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+75
-6
torchaudio/transforms.py
torchaudio/transforms.py
+25
-7
No files found.
test/torchaudio_unittest/functional/librosa_compatibility_test.py
View file @
5521f6c7
...
@@ -46,20 +46,29 @@ class TestFunctional(common_utils.TorchaudioTestCase):
...
@@ -46,20 +46,29 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self
.
assertEqual
(
ta_out
,
lr_out
,
atol
=
5e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
ta_out
,
lr_out
,
atol
=
5e-5
,
rtol
=
1e-5
)
def
_test_create_fb
(
self
,
n_mels
=
40
,
sample_rate
=
22050
,
n_fft
=
2048
,
fmin
=
0.0
,
fmax
=
8000.0
,
norm
=
None
):
def
_test_create_fb
(
self
,
n_mels
=
40
,
sample_rate
=
22050
,
n_fft
=
2048
,
fmin
=
0.0
,
fmax
=
8000.0
,
norm
=
None
,
mel_scale
=
"htk"
,
):
librosa_fb
=
librosa
.
filters
.
mel
(
sr
=
sample_rate
,
librosa_fb
=
librosa
.
filters
.
mel
(
sr
=
sample_rate
,
n_fft
=
n_fft
,
n_fft
=
n_fft
,
n_mels
=
n_mels
,
n_mels
=
n_mels
,
fmax
=
fmax
,
fmax
=
fmax
,
fmin
=
fmin
,
fmin
=
fmin
,
htk
=
True
,
htk
=
mel_scale
==
"htk"
,
norm
=
norm
)
norm
=
norm
)
fb
=
F
.
create_fb_matrix
(
sample_rate
=
sample_rate
,
fb
=
F
.
create_fb_matrix
(
sample_rate
=
sample_rate
,
n_mels
=
n_mels
,
n_mels
=
n_mels
,
f_max
=
fmax
,
f_max
=
fmax
,
f_min
=
fmin
,
f_min
=
fmin
,
n_freqs
=
(
n_fft
//
2
+
1
),
n_freqs
=
(
n_fft
//
2
+
1
),
norm
=
norm
)
norm
=
norm
,
mel_scale
=
mel_scale
)
for
i_mel_bank
in
range
(
n_mels
):
for
i_mel_bank
in
range
(
n_mels
):
self
.
assertEqual
(
self
.
assertEqual
(
...
@@ -73,6 +82,13 @@ class TestFunctional(common_utils.TorchaudioTestCase):
...
@@ -73,6 +82,13 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self
.
_test_create_fb
(
n_mels
=
56
,
fmin
=
800.0
,
fmax
=
900.0
)
self
.
_test_create_fb
(
n_mels
=
56
,
fmin
=
800.0
,
fmax
=
900.0
)
self
.
_test_create_fb
(
n_mels
=
56
,
fmin
=
1900.0
,
fmax
=
900.0
)
self
.
_test_create_fb
(
n_mels
=
56
,
fmin
=
1900.0
,
fmax
=
900.0
)
self
.
_test_create_fb
(
n_mels
=
10
,
fmin
=
1900.0
,
fmax
=
900.0
)
self
.
_test_create_fb
(
n_mels
=
10
,
fmin
=
1900.0
,
fmax
=
900.0
)
self
.
_test_create_fb
(
mel_scale
=
"slaney"
)
self
.
_test_create_fb
(
n_mels
=
128
,
sample_rate
=
44100
,
mel_scale
=
"slaney"
)
self
.
_test_create_fb
(
n_mels
=
128
,
fmin
=
2000.0
,
fmax
=
5000.0
,
mel_scale
=
"slaney"
)
self
.
_test_create_fb
(
n_mels
=
56
,
fmin
=
100.0
,
fmax
=
9000.0
,
mel_scale
=
"slaney"
)
self
.
_test_create_fb
(
n_mels
=
56
,
fmin
=
800.0
,
fmax
=
900.0
,
mel_scale
=
"slaney"
)
self
.
_test_create_fb
(
n_mels
=
56
,
fmin
=
1900.0
,
fmax
=
900.0
,
mel_scale
=
"slaney"
)
self
.
_test_create_fb
(
n_mels
=
10
,
fmin
=
1900.0
,
fmax
=
900.0
,
mel_scale
=
"slaney"
)
if
StrictVersion
(
librosa
.
__version__
)
<
StrictVersion
(
"0.7.2"
):
if
StrictVersion
(
librosa
.
__version__
)
<
StrictVersion
(
"0.7.2"
):
return
return
self
.
_test_create_fb
(
n_mels
=
128
,
sample_rate
=
44100
,
norm
=
"slaney"
)
self
.
_test_create_fb
(
n_mels
=
128
,
sample_rate
=
44100
,
norm
=
"slaney"
)
...
...
test/torchaudio_unittest/librosa_compatibility_test.py
View file @
5521f6c7
...
@@ -46,31 +46,32 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -46,31 +46,32 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self
.
assertEqual
(
out_torch
,
torch
.
from_numpy
(
out_librosa
),
atol
=
1e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
out_torch
,
torch
.
from_numpy
(
out_librosa
),
atol
=
1e-5
,
rtol
=
1e-5
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
param
(
norm
=
norm
,
**
p
.
kwargs
)
param
(
norm
=
norm
,
mel_scale
=
mel_scale
,
**
p
.
kwargs
)
for
p
in
[
for
p
in
[
param
(
n_fft
=
400
,
hop_length
=
200
,
n_mels
=
128
),
param
(
n_fft
=
400
,
hop_length
=
200
,
n_mels
=
128
),
param
(
n_fft
=
600
,
hop_length
=
100
,
n_mels
=
128
),
param
(
n_fft
=
600
,
hop_length
=
100
,
n_mels
=
128
),
param
(
n_fft
=
200
,
hop_length
=
50
,
n_mels
=
128
),
param
(
n_fft
=
200
,
hop_length
=
50
,
n_mels
=
128
),
]
]
for
norm
in
[
None
,
'slaney'
]
for
norm
in
[
None
,
'slaney'
]
for
mel_scale
in
[
'htk'
,
'slaney'
]
])
])
def
test_mel_spectrogram
(
self
,
n_fft
,
hop_length
,
n_mels
,
norm
):
def
test_mel_spectrogram
(
self
,
n_fft
,
hop_length
,
n_mels
,
norm
,
mel_scale
):
sample_rate
=
16000
sample_rate
=
16000
sound
=
common_utils
.
get_sinusoid
(
n_channels
=
1
,
sample_rate
=
sample_rate
)
sound
=
common_utils
.
get_sinusoid
(
n_channels
=
1
,
sample_rate
=
sample_rate
)
sound_librosa
=
sound
.
cpu
().
numpy
().
squeeze
()
sound_librosa
=
sound
.
cpu
().
numpy
().
squeeze
()
melspect_transform
=
torchaudio
.
transforms
.
MelSpectrogram
(
melspect_transform
=
torchaudio
.
transforms
.
MelSpectrogram
(
sample_rate
=
sample_rate
,
window_fn
=
torch
.
hann_window
,
sample_rate
=
sample_rate
,
window_fn
=
torch
.
hann_window
,
hop_length
=
hop_length
,
n_mels
=
n_mels
,
n_fft
=
n_fft
,
norm
=
norm
)
hop_length
=
hop_length
,
n_mels
=
n_mels
,
n_fft
=
n_fft
,
norm
=
norm
,
mel_scale
=
mel_scale
)
librosa_mel
=
librosa
.
feature
.
melspectrogram
(
librosa_mel
=
librosa
.
feature
.
melspectrogram
(
y
=
sound_librosa
,
sr
=
sample_rate
,
n_fft
=
n_fft
,
y
=
sound_librosa
,
sr
=
sample_rate
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
n_mels
=
n_mels
,
htk
=
True
,
norm
=
norm
)
hop_length
=
hop_length
,
n_mels
=
n_mels
,
htk
=
mel_scale
==
"htk"
,
norm
=
norm
)
librosa_mel_tensor
=
torch
.
from_numpy
(
librosa_mel
)
librosa_mel_tensor
=
torch
.
from_numpy
(
librosa_mel
)
torch_mel
=
melspect_transform
(
sound
).
squeeze
().
cpu
()
torch_mel
=
melspect_transform
(
sound
).
squeeze
().
cpu
()
self
.
assertEqual
(
self
.
assertEqual
(
torch_mel
.
type
(
librosa_mel_tensor
.
dtype
),
librosa_mel_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
torch_mel
.
type
(
librosa_mel_tensor
.
dtype
),
librosa_mel_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
param
(
norm
=
norm
,
**
p
.
kwargs
)
param
(
norm
=
norm
,
mel_scale
=
mel_scale
,
**
p
.
kwargs
)
for
p
in
[
for
p
in
[
param
(
n_fft
=
400
,
hop_length
=
200
,
power
=
2.0
,
n_mels
=
128
),
param
(
n_fft
=
400
,
hop_length
=
200
,
power
=
2.0
,
n_mels
=
128
),
param
(
n_fft
=
600
,
hop_length
=
100
,
power
=
2.0
,
n_mels
=
128
),
param
(
n_fft
=
600
,
hop_length
=
100
,
power
=
2.0
,
n_mels
=
128
),
...
@@ -79,8 +80,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -79,8 +80,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
param
(
n_fft
=
200
,
hop_length
=
50
,
power
=
2.0
,
n_mels
=
128
,
skip_ci
=
True
),
param
(
n_fft
=
200
,
hop_length
=
50
,
power
=
2.0
,
n_mels
=
128
,
skip_ci
=
True
),
]
]
for
norm
in
[
None
,
'slaney'
]
for
norm
in
[
None
,
'slaney'
]
for
mel_scale
in
[
'htk'
,
'slaney'
]
])
])
def
test_s2db
(
self
,
n_fft
,
hop_length
,
power
,
n_mels
,
norm
,
skip_ci
=
False
):
def
test_s2db
(
self
,
n_fft
,
hop_length
,
power
,
n_mels
,
norm
,
mel_scale
,
skip_ci
=
False
):
if
skip_ci
and
'CI'
in
os
.
environ
:
if
skip_ci
and
'CI'
in
os
.
environ
:
self
.
skipTest
(
'Test is known to fail on CI'
)
self
.
skipTest
(
'Test is known to fail on CI'
)
sample_rate
=
16000
sample_rate
=
16000
...
@@ -92,10 +94,10 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -92,10 +94,10 @@ class TestTransforms(common_utils.TorchaudioTestCase):
y
=
sound_librosa
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
power
=
power
)
y
=
sound_librosa
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
power
=
power
)
melspect_transform
=
torchaudio
.
transforms
.
MelSpectrogram
(
melspect_transform
=
torchaudio
.
transforms
.
MelSpectrogram
(
sample_rate
=
sample_rate
,
window_fn
=
torch
.
hann_window
,
sample_rate
=
sample_rate
,
window_fn
=
torch
.
hann_window
,
hop_length
=
hop_length
,
n_mels
=
n_mels
,
n_fft
=
n_fft
,
norm
=
norm
)
hop_length
=
hop_length
,
n_mels
=
n_mels
,
n_fft
=
n_fft
,
norm
=
norm
,
mel_scale
=
mel_scale
)
librosa_mel
=
librosa
.
feature
.
melspectrogram
(
librosa_mel
=
librosa
.
feature
.
melspectrogram
(
y
=
sound_librosa
,
sr
=
sample_rate
,
n_fft
=
n_fft
,
y
=
sound_librosa
,
sr
=
sample_rate
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
n_mels
=
n_mels
,
htk
=
True
,
norm
=
norm
)
hop_length
=
hop_length
,
n_mels
=
n_mels
,
htk
=
mel_scale
==
"htk"
,
norm
=
norm
)
power_to_db_transform
=
torchaudio
.
transforms
.
AmplitudeToDB
(
'power'
,
80.
)
power_to_db_transform
=
torchaudio
.
transforms
.
AmplitudeToDB
(
'power'
,
80.
)
power_to_db_torch
=
power_to_db_transform
(
spect_transform
(
sound
)).
squeeze
().
cpu
()
power_to_db_torch
=
power_to_db_transform
(
spect_transform
(
sound
)).
squeeze
().
cpu
()
...
...
torchaudio/functional/functional.py
View file @
5521f6c7
...
@@ -296,13 +296,81 @@ def DB_to_amplitude(
...
@@ -296,13 +296,81 @@ def DB_to_amplitude(
return
ref
*
torch
.
pow
(
torch
.
pow
(
10.0
,
0.1
*
x
),
power
)
return
ref
*
torch
.
pow
(
torch
.
pow
(
10.0
,
0.1
*
x
),
power
)
def
_hz_to_mel
(
freq
:
float
,
mel_scale
:
str
=
"htk"
)
->
float
:
r
"""Convert Hz to Mels.
Args:
freqs (float): Frequencies in Hz
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
mels (float): Frequency in Mels
"""
if
mel_scale
not
in
[
'slaney'
,
'htk'
]:
raise
ValueError
(
'mel_scale should be one of "htk" or "slaney".'
)
if
mel_scale
==
"htk"
:
return
2595.0
*
math
.
log10
(
1.0
+
(
freq
/
700.0
))
# Fill in the linear part
f_min
=
0.0
f_sp
=
200.0
/
3
mels
=
(
freq
-
f_min
)
/
f_sp
# Fill in the log-scale part
min_log_hz
=
1000.0
min_log_mel
=
(
min_log_hz
-
f_min
)
/
f_sp
logstep
=
math
.
log
(
6.4
)
/
27.0
if
freq
>=
min_log_hz
:
mels
=
min_log_mel
+
math
.
log
(
freq
/
min_log_hz
)
/
logstep
return
mels
def
_mel_to_hz
(
mels
:
Tensor
,
mel_scale
:
str
=
"htk"
)
->
Tensor
:
"""Convert mel bin numbers to frequencies.
Args:
mels (Tensor): Mel frequencies
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
freqs (Tensor): Mels converted in Hz
"""
if
mel_scale
not
in
[
'slaney'
,
'htk'
]:
raise
ValueError
(
'mel_scale should be one of "htk" or "slaney".'
)
if
mel_scale
==
"htk"
:
return
700.0
*
(
10.0
**
(
mels
/
2595.0
)
-
1.0
)
# Fill in the linear scale
f_min
=
0.0
f_sp
=
200.0
/
3
freqs
=
f_min
+
f_sp
*
mels
# And now the nonlinear scale
min_log_hz
=
1000.0
min_log_mel
=
(
min_log_hz
-
f_min
)
/
f_sp
logstep
=
math
.
log
(
6.4
)
/
27.0
log_t
=
(
mels
>=
min_log_mel
)
freqs
[
log_t
]
=
min_log_hz
*
torch
.
exp
(
logstep
*
(
mels
[
log_t
]
-
min_log_mel
))
return
freqs
def
create_fb_matrix
(
def
create_fb_matrix
(
n_freqs
:
int
,
n_freqs
:
int
,
f_min
:
float
,
f_min
:
float
,
f_max
:
float
,
f_max
:
float
,
n_mels
:
int
,
n_mels
:
int
,
sample_rate
:
int
,
sample_rate
:
int
,
norm
:
Optional
[
str
]
=
None
norm
:
Optional
[
str
]
=
None
,
mel_scale
:
str
=
"htk"
,
)
->
Tensor
:
)
->
Tensor
:
r
"""Create a frequency bin conversion matrix.
r
"""Create a frequency bin conversion matrix.
...
@@ -314,6 +382,7 @@ def create_fb_matrix(
...
@@ -314,6 +382,7 @@ def create_fb_matrix(
sample_rate (int): Sample rate of the audio waveform
sample_rate (int): Sample rate of the audio waveform
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
...
@@ -331,12 +400,12 @@ def create_fb_matrix(
...
@@ -331,12 +400,12 @@ def create_fb_matrix(
all_freqs
=
torch
.
linspace
(
0
,
sample_rate
//
2
,
n_freqs
)
all_freqs
=
torch
.
linspace
(
0
,
sample_rate
//
2
,
n_freqs
)
# calculate mel freq bins
# calculate mel freq bins
# hertz
to
mel(f
) is 2595. * math.log10(1. + (f / 700.)
)
m_min
=
_hz_
to
_
mel
(
f
_min
,
mel_scale
=
mel_scale
)
m_m
in
=
2595.0
*
math
.
log10
(
1.0
+
(
f_min
/
700.0
)
)
m_m
ax
=
_hz_to_mel
(
f_max
,
mel_scale
=
mel_scale
)
m_max
=
2595.0
*
math
.
log10
(
1.0
+
(
f_max
/
700.0
))
m_pts
=
torch
.
linspace
(
m_min
,
m_max
,
n_mels
+
2
)
m_pts
=
torch
.
linspace
(
m_min
,
m_max
,
n_mels
+
2
)
#
mel
to
hertz(mel) is 700. * (10**(mel / 2595.) - 1.
)
f_pts
=
_
mel
_
to
_hz
(
m_pts
,
mel_scale
=
mel_scale
)
f_pts
=
700.0
*
(
10
**
(
m_pts
/
2595.0
)
-
1.0
)
# calculate the difference between each mel point and each stft freq point in hertz
# calculate the difference between each mel point and each stft freq point in hertz
f_diff
=
f_pts
[
1
:]
-
f_pts
[:
-
1
]
# (n_mels + 1)
f_diff
=
f_pts
[
1
:]
-
f_pts
[:
-
1
]
# (n_mels + 1)
slopes
=
f_pts
.
unsqueeze
(
0
)
-
all_freqs
.
unsqueeze
(
1
)
# (n_freqs, n_mels + 2)
slopes
=
f_pts
.
unsqueeze
(
0
)
-
all_freqs
.
unsqueeze
(
1
)
# (n_freqs, n_mels + 2)
...
...
torchaudio/transforms.py
View file @
5521f6c7
...
@@ -249,6 +249,7 @@ class MelScale(torch.nn.Module):
...
@@ -249,6 +249,7 @@ class MelScale(torch.nn.Module):
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
"""
"""
__constants__
=
[
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
]
__constants__
=
[
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
]
...
@@ -258,18 +259,21 @@ class MelScale(torch.nn.Module):
...
@@ -258,18 +259,21 @@ class MelScale(torch.nn.Module):
f_min
:
float
=
0.
,
f_min
:
float
=
0.
,
f_max
:
Optional
[
float
]
=
None
,
f_max
:
Optional
[
float
]
=
None
,
n_stft
:
Optional
[
int
]
=
None
,
n_stft
:
Optional
[
int
]
=
None
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
norm
:
Optional
[
str
]
=
None
,
mel_scale
:
str
=
"htk"
)
->
None
:
super
(
MelScale
,
self
).
__init__
()
super
(
MelScale
,
self
).
__init__
()
self
.
n_mels
=
n_mels
self
.
n_mels
=
n_mels
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
f_max
=
f_max
if
f_max
is
not
None
else
float
(
sample_rate
//
2
)
self
.
f_max
=
f_max
if
f_max
is
not
None
else
float
(
sample_rate
//
2
)
self
.
f_min
=
f_min
self
.
f_min
=
f_min
self
.
norm
=
norm
self
.
norm
=
norm
self
.
mel_scale
=
mel_scale
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
torch
.
empty
(
0
)
if
n_stft
is
None
else
F
.
create_fb_matrix
(
fb
=
torch
.
empty
(
0
)
if
n_stft
is
None
else
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
)
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
,
self
.
mel_scale
)
self
.
register_buffer
(
'fb'
,
fb
)
self
.
register_buffer
(
'fb'
,
fb
)
def
forward
(
self
,
specgram
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
specgram
:
Tensor
)
->
Tensor
:
...
@@ -287,7 +291,8 @@ class MelScale(torch.nn.Module):
...
@@ -287,7 +291,8 @@ class MelScale(torch.nn.Module):
if
self
.
fb
.
numel
()
==
0
:
if
self
.
fb
.
numel
()
==
0
:
tmp_fb
=
F
.
create_fb_matrix
(
specgram
.
size
(
1
),
self
.
f_min
,
self
.
f_max
,
tmp_fb
=
F
.
create_fb_matrix
(
specgram
.
size
(
1
),
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
)
self
.
n_mels
,
self
.
sample_rate
,
self
.
norm
,
self
.
mel_scale
)
# Attributes cannot be reassigned outside __init__ so workaround
# Attributes cannot be reassigned outside __init__ so workaround
self
.
fb
.
resize_
(
tmp_fb
.
size
())
self
.
fb
.
resize_
(
tmp_fb
.
size
())
self
.
fb
.
copy_
(
tmp_fb
)
self
.
fb
.
copy_
(
tmp_fb
)
...
@@ -321,6 +326,7 @@ class InverseMelScale(torch.nn.Module):
...
@@ -321,6 +326,7 @@ class InverseMelScale(torch.nn.Module):
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
"""
"""
__constants__
=
[
'n_stft'
,
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
,
'max_iter'
,
'tolerance_loss'
,
__constants__
=
[
'n_stft'
,
'n_mels'
,
'sample_rate'
,
'f_min'
,
'f_max'
,
'max_iter'
,
'tolerance_loss'
,
'tolerance_change'
,
'sgdargs'
]
'tolerance_change'
,
'sgdargs'
]
...
@@ -335,7 +341,8 @@ class InverseMelScale(torch.nn.Module):
...
@@ -335,7 +341,8 @@ class InverseMelScale(torch.nn.Module):
tolerance_loss
:
float
=
1e-5
,
tolerance_loss
:
float
=
1e-5
,
tolerance_change
:
float
=
1e-8
,
tolerance_change
:
float
=
1e-8
,
sgdargs
:
Optional
[
dict
]
=
None
,
sgdargs
:
Optional
[
dict
]
=
None
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
norm
:
Optional
[
str
]
=
None
,
mel_scale
:
str
=
"htk"
)
->
None
:
super
(
InverseMelScale
,
self
).
__init__
()
super
(
InverseMelScale
,
self
).
__init__
()
self
.
n_mels
=
n_mels
self
.
n_mels
=
n_mels
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
...
@@ -348,7 +355,8 @@ class InverseMelScale(torch.nn.Module):
...
@@ -348,7 +355,8 @@ class InverseMelScale(torch.nn.Module):
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
assert
f_min
<=
self
.
f_max
,
'Require f_min: {} < f_max: {}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
norm
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
,
norm
,
mel_scale
)
self
.
register_buffer
(
'fb'
,
fb
)
self
.
register_buffer
(
'fb'
,
fb
)
def
forward
(
self
,
melspec
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
melspec
:
Tensor
)
->
Tensor
:
...
@@ -427,6 +435,7 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -427,6 +435,7 @@ class MelSpectrogram(torch.nn.Module):
avoid redundancy. Default: ``True``
avoid redundancy. Default: ``True``
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Example
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
...
@@ -450,7 +459,8 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -450,7 +459,8 @@ class MelSpectrogram(torch.nn.Module):
center
:
bool
=
True
,
center
:
bool
=
True
,
pad_mode
:
str
=
"reflect"
,
pad_mode
:
str
=
"reflect"
,
onesided
:
bool
=
True
,
onesided
:
bool
=
True
,
norm
:
Optional
[
str
]
=
None
)
->
None
:
norm
:
Optional
[
str
]
=
None
,
mel_scale
:
str
=
"htk"
)
->
None
:
super
(
MelSpectrogram
,
self
).
__init__
()
super
(
MelSpectrogram
,
self
).
__init__
()
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
n_fft
=
n_fft
self
.
n_fft
=
n_fft
...
@@ -467,7 +477,15 @@ class MelSpectrogram(torch.nn.Module):
...
@@ -467,7 +477,15 @@ class MelSpectrogram(torch.nn.Module):
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
self
.
power
,
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
self
.
power
,
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
,
normalized
=
self
.
normalized
,
wkwargs
=
wkwargs
,
center
=
center
,
pad_mode
=
pad_mode
,
onesided
=
onesided
)
center
=
center
,
pad_mode
=
pad_mode
,
onesided
=
onesided
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
,
norm
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
,
self
.
n_fft
//
2
+
1
,
norm
,
mel_scale
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment