Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
8b616bce
Commit
8b616bce
authored
Jan 05, 2019
by
David Pollack
Committed by
Soumith Chintala
Jan 05, 2019
Browse files
fix MEL2 and update filterbank conversion matrix
parent
a8d6a41b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
31 deletions
+56
-31
test/test_transforms.py
test/test_transforms.py
+20
-2
torchaudio/transforms.py
torchaudio/transforms.py
+36
-29
No files found.
test/test_transforms.py
View file @
8b616bce
...
@@ -155,13 +155,26 @@ class Tester(unittest.TestCase):
...
@@ -155,13 +155,26 @@ class Tester(unittest.TestCase):
audio_orig
=
self
.
sig
.
clone
()
# (16000, 1)
audio_orig
=
self
.
sig
.
clone
()
# (16000, 1)
audio_scaled
=
transforms
.
Scale
()(
audio_orig
)
# (16000, 1)
audio_scaled
=
transforms
.
Scale
()(
audio_orig
)
# (16000, 1)
audio_scaled
=
transforms
.
LC2CL
()(
audio_scaled
)
# (1, 16000)
audio_scaled
=
transforms
.
LC2CL
()(
audio_scaled
)
# (1, 16000)
mel_transform
=
transforms
.
MEL2
(
window
=
torch
.
hamming_window
,
pad
=
10
)
mel_transform
=
transforms
.
MEL2
()
# check defaults
spectrogram_torch
=
mel_transform
(
audio_scaled
)
# (1, 319, 40)
spectrogram_torch
=
mel_transform
(
audio_scaled
)
# (1, 319, 40)
self
.
assertTrue
(
spectrogram_torch
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_torch
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_torch
.
le
(
0.
).
all
())
self
.
assertTrue
(
spectrogram_torch
.
le
(
0.
).
all
())
self
.
assertTrue
(
spectrogram_torch
.
ge
(
mel_transform
.
top_db
).
all
())
self
.
assertTrue
(
spectrogram_torch
.
ge
(
mel_transform
.
top_db
).
all
())
self
.
assertEqual
(
spectrogram_torch
.
size
(
-
1
),
mel_transform
.
n_mels
)
self
.
assertEqual
(
spectrogram_torch
.
size
(
-
1
),
mel_transform
.
n_mels
)
# load stereo file
# check correctness of filterbank conversion matrix
self
.
assertTrue
(
mel_transform
.
fm
.
fb
.
sum
(
1
).
le
(
1.
).
all
())
self
.
assertTrue
(
mel_transform
.
fm
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
# check options
mel_transform2
=
transforms
.
MEL2
(
window
=
torch
.
hamming_window
,
pad
=
10
,
ws
=
500
,
hop
=
125
,
n_fft
=
800
,
n_mels
=
50
)
spectrogram2_torch
=
mel_transform2
(
audio_scaled
)
# (1, 506, 50)
self
.
assertTrue
(
spectrogram2_torch
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram2_torch
.
le
(
0.
).
all
())
self
.
assertTrue
(
spectrogram2_torch
.
ge
(
mel_transform
.
top_db
).
all
())
self
.
assertEqual
(
spectrogram2_torch
.
size
(
-
1
),
mel_transform2
.
n_mels
)
self
.
assertTrue
(
mel_transform2
.
fm
.
fb
.
sum
(
1
).
le
(
1.
).
all
())
self
.
assertTrue
(
mel_transform2
.
fm
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
# check on multi-channel audio
x_stereo
,
sr_stereo
=
torchaudio
.
load
(
self
.
test_filepath
)
x_stereo
,
sr_stereo
=
torchaudio
.
load
(
self
.
test_filepath
)
spectrogram_stereo
=
mel_transform
(
x_stereo
)
spectrogram_stereo
=
mel_transform
(
x_stereo
)
self
.
assertTrue
(
spectrogram_stereo
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_stereo
.
dim
()
==
3
)
...
@@ -169,6 +182,11 @@ class Tester(unittest.TestCase):
...
@@ -169,6 +182,11 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
spectrogram_stereo
.
le
(
0.
).
all
())
self
.
assertTrue
(
spectrogram_stereo
.
le
(
0.
).
all
())
self
.
assertTrue
(
spectrogram_stereo
.
ge
(
mel_transform
.
top_db
).
all
())
self
.
assertTrue
(
spectrogram_stereo
.
ge
(
mel_transform
.
top_db
).
all
())
self
.
assertEqual
(
spectrogram_stereo
.
size
(
-
1
),
mel_transform
.
n_mels
)
self
.
assertEqual
(
spectrogram_stereo
.
size
(
-
1
),
mel_transform
.
n_mels
)
# check filterbank matrix creation
fb_matrix_transform
=
transforms
.
F2M
(
n_mels
=
100
,
sr
=
16000
,
f_max
=
None
,
f_min
=
0.
,
n_stft
=
400
)
self
.
assertTrue
(
fb_matrix_transform
.
fb
.
sum
(
1
).
le
(
1.
).
all
())
self
.
assertTrue
(
fb_matrix_transform
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
self
.
assertEqual
(
fb_matrix_transform
.
fb
.
size
(),
(
400
,
100
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
torchaudio/transforms.py
View file @
8b616bce
...
@@ -160,23 +160,22 @@ class SPECTROGRAM(object):
...
@@ -160,23 +160,22 @@ class SPECTROGRAM(object):
Args:
Args:
sr (int): sample rate of audio signal
sr (int): sample rate of audio signal
ws (int): window size
, often called the fft size as well
ws (int): window size
hop (int, optional): length of hop between STFT windows. default: ws // 2
hop (int, optional): length of hop between STFT windows. default: ws // 2
n_fft (int, optional):
number
of fft bins. default: ws
// 2 + 1
n_fft (int, optional):
size
of fft
, creates n_fft // 2 + 1
bins. default: ws
pad (int): two sided padding of signal
pad (int): two sided padding of signal
window (torch windowing function): default: torch.hann_window
window (torch windowing function): default: torch.hann_window
wkwargs (dict, optional): arguments for window function
wkwargs (dict, optional): arguments for window function
"""
"""
def
__init__
(
self
,
sr
=
16000
,
ws
=
400
,
hop
=
None
,
n_fft
=
None
,
def
__init__
(
self
,
ws
=
400
,
hop
=
None
,
n_fft
=
None
,
pad
=
0
,
window
=
torch
.
hann_window
,
wkwargs
=
None
):
pad
=
0
,
window
=
torch
.
hann_window
,
wkwargs
=
None
):
self
.
window
=
window
(
ws
)
if
wkwargs
is
None
else
window
(
ws
,
**
wkwargs
)
self
.
window
=
window
(
ws
)
if
wkwargs
is
None
else
window
(
ws
,
**
wkwargs
)
self
.
sr
=
sr
self
.
ws
=
ws
self
.
ws
=
ws
self
.
hop
=
hop
if
hop
is
not
None
else
ws
//
2
self
.
hop
=
hop
if
hop
is
not
None
else
ws
//
2
# number of fft bins. the returned STFT result will have n_fft // 2 + 1
# number of fft bins. the returned STFT result will have n_fft // 2 + 1
# number of frequecies due to onesided=True in torch.stft
# number of frequecies due to onesided=True in torch.stft
self
.
n_fft
=
(
n_fft
-
1
)
*
2
if
n_fft
is
not
None
else
ws
self
.
n_fft
=
n_fft
if
n_fft
is
not
None
else
ws
self
.
pad
=
pad
self
.
pad
=
pad
self
.
wkwargs
=
wkwargs
self
.
wkwargs
=
wkwargs
...
@@ -212,17 +211,17 @@ class F2M(object):
...
@@ -212,17 +211,17 @@ class F2M(object):
Args:
Args:
n_mels (int): number of MEL bins
n_mels (int): number of MEL bins
sr (int): sample rate of audio signal
sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: sr // 2
f_max (float, optional): maximum frequency. default:
`
sr
`
// 2
f_min (float): minimum frequency. default: 0
f_min (float): minimum frequency. default: 0
n_
f
ft (int, optional): number of filter banks from stft. Calculated from first input
n_
st
ft (int, optional): number of filter banks from stft. Calculated from first input
if `None` is given.
if `None` is given.
See `n_fft` in `SPECTROGRAM`.
"""
"""
def
__init__
(
self
,
n_mels
=
40
,
sr
=
16000
,
f_max
=
None
,
f_min
=
0.
,
n_
f
ft
=
None
):
def
__init__
(
self
,
n_mels
=
40
,
sr
=
16000
,
f_max
=
None
,
f_min
=
0.
,
n_
st
ft
=
None
):
self
.
n_mels
=
n_mels
self
.
n_mels
=
n_mels
self
.
sr
=
sr
self
.
sr
=
sr
self
.
f_max
=
f_max
if
f_max
is
not
None
else
sr
//
2
self
.
f_max
=
f_max
if
f_max
is
not
None
else
sr
//
2
self
.
f_min
=
f_min
self
.
f_min
=
f_min
self
.
fb
=
self
.
_create_fb_matrix
(
n_
f
ft
)
if
n_
f
ft
is
not
None
else
n_
f
ft
self
.
fb
=
self
.
_create_fb_matrix
(
n_
st
ft
)
if
n_
st
ft
is
not
None
else
n_
st
ft
def
__call__
(
self
,
spec_f
):
def
__call__
(
self
,
spec_f
):
if
self
.
fb
is
None
:
if
self
.
fb
is
None
:
...
@@ -230,27 +229,35 @@ class F2M(object):
...
@@ -230,27 +229,35 @@ class F2M(object):
spec_m
=
torch
.
matmul
(
spec_f
,
self
.
fb
)
# (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
spec_m
=
torch
.
matmul
(
spec_f
,
self
.
fb
)
# (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return
spec_m
return
spec_m
def
_create_fb_matrix
(
self
,
n_
f
ft
):
def
_create_fb_matrix
(
self
,
n_
st
ft
):
""" Create a frequency bin conversion matrix.
""" Create a frequency bin conversion matrix.
Args:
Args:
n_
f
ft (int): number of filter banks from spectrogram
n_
st
ft (int): number of filter banks from spectrogram
"""
"""
m_min
=
0.
if
self
.
f_min
==
0
else
2595
*
np
.
log10
(
1.
+
(
self
.
f_min
/
700
))
# get stft freq bins
m_max
=
2595
*
np
.
log10
(
1.
+
(
self
.
f_max
/
700
))
stft_freqs
=
torch
.
linspace
(
self
.
f_min
,
self
.
f_max
,
n_stft
)
# calculate mel freq bins
m_min
=
0.
if
self
.
f_min
==
0
else
self
.
_hertz_to_mel
(
self
.
f_min
)
m_max
=
self
.
_hertz_to_mel
(
self
.
f_max
)
m_pts
=
torch
.
linspace
(
m_min
,
m_max
,
self
.
n_mels
+
2
)
m_pts
=
torch
.
linspace
(
m_min
,
m_max
,
self
.
n_mels
+
2
)
f_pts
=
(
700
*
(
10
**
(
m_pts
/
2595
)
-
1
))
f_pts
=
self
.
_mel_to_hertz
(
m_pts
)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff
=
f_pts
[
1
:]
-
f_pts
[:
-
1
]
# (n_mels + 1)
slopes
=
f_pts
.
unsqueeze
(
0
)
-
stft_freqs
.
unsqueeze
(
1
)
# (n_stft, n_mels + 2)
# create overlapping triangles
z
=
torch
.
tensor
(
0.
)
down_slopes
=
(
-
1.
*
slopes
[:,
:
-
2
])
/
f_diff
[:
-
1
]
# (n_stft, n_mels)
up_slopes
=
slopes
[:,
2
:]
/
f_diff
[
1
:]
# (n_stft, n_mels)
fb
=
torch
.
max
(
z
,
torch
.
min
(
down_slopes
,
up_slopes
))
return
fb
bins
=
torch
.
floor
(((
n_fft
-
1
)
*
2
)
*
f_pts
/
self
.
sr
).
long
()
def
_hertz_to_mel
(
self
,
f
):
return
2595.
*
torch
.
log10
(
torch
.
tensor
(
1.
)
+
(
f
/
700.
))
fb
=
torch
.
zeros
(
n_fft
,
self
.
n_mels
,
dtype
=
torch
.
float
)
def
_mel_to_hertz
(
self
,
mel
):
for
m
in
range
(
1
,
self
.
n_mels
+
1
):
return
700.
*
(
10
**
(
mel
/
2595.
)
-
1.
)
f_m_minus
=
bins
[
m
-
1
].
item
()
f_m_plus
=
bins
[
m
+
1
].
item
()
fb
[
f_m_minus
:
f_m_plus
,
m
-
1
]
=
torch
.
bartlett_window
(
f_m_plus
-
f_m_minus
)
return
fb
class
SPEC2DB
(
object
):
class
SPEC2DB
(
object
):
...
@@ -287,12 +294,12 @@ class MEL2(object):
...
@@ -287,12 +294,12 @@ class MEL2(object):
Args:
Args:
sr (int): sample rate of audio signal
sr (int): sample rate of audio signal
ws (int): window size
, often called the fft size as well
ws (int): window size
hop (int, optional): length of hop between STFT windows. default: ws // 2
hop (int, optional): length of hop between STFT windows. default:
`
ws
`
// 2
n_fft (int, optional): number of fft bins. default: ws // 2 + 1
n_fft (int, optional): number of fft bins. default:
`
ws
`
// 2 + 1
pad (int): two sided padding of signal
pad (int): two sided padding of signal
n_mels (int): number of MEL bins
n_mels (int): number of MEL bins
window (torch windowing function): default: torch.hann_window
window (torch windowing function): default:
`
torch.hann_window
`
wkwargs (dict, optional): arguments for window function
wkwargs (dict, optional): arguments for window function
Example:
Example:
...
@@ -312,9 +319,9 @@ class MEL2(object):
...
@@ -312,9 +319,9 @@ class MEL2(object):
self
.
top_db
=
-
80.
self
.
top_db
=
-
80.
self
.
f_max
=
None
self
.
f_max
=
None
self
.
f_min
=
0.
self
.
f_min
=
0.
self
.
spec
=
SPECTROGRAM
(
self
.
sr
,
self
.
ws
,
self
.
hop
,
self
.
n_fft
,
self
.
spec
=
SPECTROGRAM
(
self
.
ws
,
self
.
hop
,
self
.
n_fft
,
self
.
pad
,
self
.
window
,
self
.
wkwargs
)
self
.
pad
,
self
.
window
,
self
.
wkwargs
)
self
.
fm
=
F2M
(
self
.
n_mels
,
self
.
sr
,
self
.
f_max
,
self
.
f_min
,
self
.
n_fft
)
self
.
fm
=
F2M
(
self
.
n_mels
,
self
.
sr
,
self
.
f_max
,
self
.
f_min
)
self
.
s2db
=
SPEC2DB
(
"power"
,
self
.
top_db
)
self
.
s2db
=
SPEC2DB
(
"power"
,
self
.
top_db
)
self
.
transforms
=
Compose
([
self
.
transforms
=
Compose
([
self
.
spec
,
self
.
fm
,
self
.
s2db
,
self
.
spec
,
self
.
fm
,
self
.
s2db
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment