OpenDAS / Torchaudio / Commits

Commit a8d6a41b
Authored Jan 04, 2019 by David Pollack; committed by Soumith Chintala on Jan 04, 2019
Parent: 3bd4db86

optimization to MEL2 and fixes to filter bank conversion function
Showing 2 changed files with 64 additions and 44 deletions:

  test/test_transforms.py     +30  -13
  torchaudio/transforms.py    +34  -31
test/test_transforms.py

 from __future__ import print_function
+import os
 import torch
 import torchaudio
 import torchaudio.transforms as transforms
@@ -8,13 +9,17 @@ import unittest
 class Tester(unittest.TestCase):
     # create a sinewave signal for testing
     sr = 16000
     freq = 440
     volume = .3
     sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr).float() * freq / sr))
     # sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
-    sig.unsqueeze_(1)
+    sig.unsqueeze_(1)  # (64000, 1)
     sig = (sig * volume * 2**31).long()
+    # file for stereo stft test
+    test_dirpath = os.path.dirname(os.path.realpath(__file__))
+    test_filepath = os.path.join(test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3")

     def test_scale(self):
@@ -29,7 +34,7 @@ class Tester(unittest.TestCase):
             result.min() >= -1. and result.max() <= 1.)
         repr_test = transforms.Scale()
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())

     def test_pad_trim(self):
@@ -52,7 +57,7 @@ class Tester(unittest.TestCase):
         self.assertEqual(result.size(0), length_new)
         repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())

     def test_downmix_mono(self):
@@ -70,7 +75,7 @@ class Tester(unittest.TestCase):
         self.assertTrue(result.size(1) == 1)
         repr_test = transforms.DownmixMono(channels_first=False)
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())

     def test_lc2cl(self):
@@ -79,7 +84,7 @@ class Tester(unittest.TestCase):
         self.assertTrue(result.size()[::-1] == audio.size())
         repr_test = transforms.LC2CL()
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())

     def test_mel(self):
@@ -92,9 +97,10 @@ class Tester(unittest.TestCase):
         self.assertTrue(result.dim() == 3)
         repr_test = transforms.MEL()
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())
         repr_test = transforms.BLC2CBL()
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())

     def test_compose(self):
@@ -113,7 +119,7 @@ class Tester(unittest.TestCase):
         self.assertTrue(result.size(0) == length_new)
         repr_test = transforms.Compose(tset)
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())

     def test_mu_law_companding(self):
@@ -141,17 +147,28 @@ class Tester(unittest.TestCase):
         self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
         repr_test = transforms.MuLawEncoding(quantization_channels)
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())
         repr_test = transforms.MuLawExpanding(quantization_channels)
-        repr_test.__repr__()
+        self.assertTrue(repr_test.__repr__())

     def test_mel2(self):
         audio_orig = self.sig.clone()  # (16000, 1)
         audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
         audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
-        spectrogram_torch = transforms.MEL2(window_fn=torch.hamming_window, pad=10)(audio_scaled)  # (1, 319, 40)
+        mel_transform = transforms.MEL2(window=torch.hamming_window, pad=10)
+        spectrogram_torch = mel_transform(audio_scaled)  # (1, 319, 40)
         self.assertTrue(spectrogram_torch.dim() == 3)
-        self.assertTrue(spectrogram_torch.max() <= 0.)
+        self.assertTrue(spectrogram_torch.le(0.).all())
+        self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
+        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
+        # load stereo file
+        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
+        spectrogram_stereo = mel_transform(x_stereo)
+        self.assertTrue(spectrogram_stereo.dim() == 3)
+        self.assertTrue(spectrogram_stereo.size(0) == 2)
+        self.assertTrue(spectrogram_stereo.le(0.).all())
+        self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
+        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)


 if __name__ == '__main__':
     unittest.main()
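A note on the new assertions in test_mel2: MEL2 ends in SPEC2DB("power", top_db=-80.), which rescales the spectrogram relative to its own maximum and then clamps from below, so every output value should fall in [top_db, 0]. A minimal sketch of that reasoning, assuming the usual 10 * log10 scaling for "power" and using arbitrary stand-in data:

    import torch

    spec = torch.rand(1, 319, 40) + 1e-6                # stand-in for a mel power spectrogram
    top_db = -80.

    spec_db = 10. * torch.log10(spec / spec.max())       # dB relative to the maximum, so always <= 0
    spec_db = torch.max(spec_db, torch.tensor(top_db))   # clamp from below at top_db

    assert spec_db.le(0.).all() and spec_db.ge(top_db).all()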
torchaudio/transforms.py
@@ -60,7 +60,7 @@ class Scale(object):
             Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
         """
-        if not tensor.is_floating_point():
+        if not tensor.dtype.is_floating_point:
             tensor = tensor.to(torch.float32)
         return tensor / self.factor
@@ -125,7 +125,7 @@ class DownmixMono(object):
         self.ch_dim = int(not channels_first)

     def __call__(self, tensor):
-        if not tensor.is_floating_point():
+        if not tensor.dtype.is_floating_point:
             tensor = tensor.to(torch.float32)
         tensor = torch.mean(tensor, self.ch_dim, True)
@@ -169,8 +169,8 @@ class SPECTROGRAM(object):
     """
     def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
-                 pad=0, window_fn=torch.hann_window, wkwargs=None):
-        self.window = window_fn(ws) if wkwargs is None else window_fn(ws, **wkwargs)
+                 pad=0, window=torch.hann_window, wkwargs=None):
+        self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
         self.sr = sr
         self.ws = ws
         self.hop = hop if hop is not None else ws // 2
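Worth noting: the window argument here is a window factory (a callable such as torch.hann_window), not a precomputed tensor; the constructor materializes the window tensor once. A small sketch of that pattern, with ws and wkwargs picked arbitrarily:

    import torch

    ws = 400                           # analysis window length, SPECTROGRAM's default
    window = torch.hamming_window      # any torch window factory works here
    wkwargs = {"periodic": False}      # optional kwargs forwarded to the factory

    # built once, exactly as in the constructor above
    win = window(ws) if wkwargs is None else window(ws, **wkwargs)
    print(win.shape)                   # torch.Size([400])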
@@ -197,7 +197,6 @@ class SPECTROGRAM(object):
         if self.pad > 0:
             with torch.no_grad():
                 sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
         spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
                             self.window, center=False,
                             normalized=True, onesided=True).transpose(1, 2)
@@ -215,16 +214,28 @@ class F2M(object):
         sr (int): sample rate of audio signal
         f_max (float, optional): maximum frequency. default: sr // 2
         f_min (float): minimum frequency. default: 0
+        n_fft (int, optional): number of filter banks from stft. Calculated from first input
+            if `None` is given.
     """
-    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
+    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_fft=None):
         self.n_mels = n_mels
         self.sr = sr
         self.f_max = f_max if f_max is not None else sr // 2
         self.f_min = f_min
+        self.fb = self._create_fb_matrix(n_fft) if n_fft is not None else n_fft

     def __call__(self, spec_f):
-        n_fft = spec_f.size(2)
+        if self.fb is None:
+            self.fb = self._create_fb_matrix(spec_f.size(2))
+        spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
+        return spec_m
+
+    def _create_fb_matrix(self, n_fft):
+        """ Create a frequency bin conversion matrix.
+
+        Args:
+            n_fft (int): number of filter banks from spectrogram
+        """
         m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
         m_max = 2595 * np.log10(1. + (self.f_max / 700))
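F2M works on the mel scale via the HTK-style conversion that appears just above, m = 2595 * log10(1 + f / 700). A quick sketch of the conversion and its inverse; the helper names are illustrative, not part of the module:

    import numpy as np

    def hz_to_mel(f):
        # same expression as in _create_fb_matrix
        return 2595. * np.log10(1. + f / 700.)

    def mel_to_hz(m):
        # inverse of the expression above
        return 700. * (10. ** (m / 2595.) - 1.)

    print(hz_to_mel(8000.))            # ~2840 mel at the Nyquist frequency of a 16 kHz signal
    print(mel_to_hz(hz_to_mel(440.)))  # ~440.0, round trip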
@@ -234,19 +245,12 @@ class F2M(object):
         bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()
-        fb = torch.zeros(n_fft, self.n_mels)
+        fb = torch.zeros(n_fft, self.n_mels, dtype=torch.float)
         for m in range(1, self.n_mels + 1):
             f_m_minus = bins[m - 1].item()
             f_m = bins[m].item()
             f_m_plus = bins[m + 1].item()

-            if f_m_minus != f_m:
-                fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
-            if f_m != f_m_plus:
-                fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)
-        spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
-        return spec_m
+            fb[f_m_minus:f_m_plus, m - 1] = torch.bartlett_window(f_m_plus - f_m_minus)
+        return fb


 class SPEC2DB(object):
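The rewritten loop fills each filter in one step with torch.bartlett_window, a triangular ramp up and back down, instead of building the rising and falling ramps separately. The two agree when the centre bin f_m sits at the midpoint of [f_m_minus, f_m_plus) and can differ slightly otherwise. A small sketch of the shapes involved, with example bin edges:

    import torch

    f_m_minus, f_m, f_m_plus = 10, 14, 18   # example FFT-bin edges of one mel filter

    # previous approach: explicit rising and falling ramps
    ramps = torch.zeros(f_m_plus - f_m_minus)
    ramps[:f_m - f_m_minus] = (torch.arange(f_m_minus, f_m).float() - f_m_minus) / (f_m - f_m_minus)
    ramps[f_m - f_m_minus:] = (f_m_plus - torch.arange(f_m, f_m_plus).float()) / (f_m_plus - f_m)

    # new approach: one triangular (Bartlett) window over the whole band
    tri = torch.bartlett_window(f_m_plus - f_m_minus)

    print(ramps)  # tensor([0.00, 0.25, 0.50, 0.75, 1.00, 0.75, 0.50, 0.25])
    print(tri)    # identical here because f_m is the midpoint of the band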
@@ -267,7 +271,7 @@ class SPEC2DB(object):
         spec_db = self.multiplier * torch.log10(spec / spec.max())  # power -> dB
         if self.top_db is not None:
-            spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
+            spec_db = torch.max(spec_db, torch.tensor(self.top_db, dtype=spec_db.dtype))
         return spec_db
@@ -296,8 +300,8 @@ class MEL2(object):
        >>> spec_mel = transforms.MEL2(sr)(sig)  # (c, l, m)
     """
     def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
-                 pad=0, n_mels=40, window_fn=torch.hann_window, wkwargs=None):
-        self.window_fn = window_fn
+                 pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
+        self.window = window
         self.sr = sr
         self.ws = ws
         self.hop = hop if hop is not None else ws // 2
@@ -308,6 +312,13 @@ class MEL2(object):
         self.top_db = -80.
         self.f_max = None
         self.f_min = 0.
+        self.spec = SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
+                                self.pad, self.window, self.wkwargs)
+        self.fm = F2M(self.n_mels, self.sr, self.f_max, self.f_min, self.n_fft)
+        self.s2db = SPEC2DB("power", self.top_db)
+        self.transforms = Compose([
+            self.spec, self.fm, self.s2db,
+        ])

     def __call__(self, sig):
         """
@@ -320,15 +331,7 @@ class MEL2(object):
                 number of mel bins.
         """
-        transforms = Compose([
-            SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
-                        self.pad, self.window_fn, self.wkwargs),
-            F2M(self.n_mels, self.sr, self.f_max, self.f_min),
-            SPEC2DB("power", self.top_db),
-        ])
-        spec_mel_db = transforms(sig)
+        spec_mel_db = self.transforms(sig)
         return spec_mel_db
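The practical effect of this change is that the SPECTROGRAM -> F2M -> SPEC2DB pipeline is assembled once in __init__ and only applied in __call__, instead of being rebuilt on every call (and F2M can precompute its filter bank when n_fft is given). A hedged usage sketch; the input tensor is a placeholder:

    import torch
    import torchaudio.transforms as transforms

    sig = torch.randn(1, 16000)  # placeholder mono signal, (c, n)
    mel_transform = transforms.MEL2(sr=16000, ws=400, n_mels=40,
                                    window=torch.hamming_window, pad=10)

    # the Compose pipeline now lives on the instance, so repeated calls reuse it
    for _ in range(3):
        spec_mel_db = mel_transform(sig)  # (c, l, n_mels), values in [top_db, 0]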
@@ -426,7 +429,7 @@ class MuLawEncoding(object):
             x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
             x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
         elif isinstance(x, torch.Tensor):
-            if not x.is_floating_point():
+            if not x.dtype.is_floating_point:
                 x = x.to(torch.float)
             mu = torch.tensor(mu, dtype=x.dtype)
             x_mu = torch.sign(x) * torch.log1p(mu *
@@ -468,7 +471,7 @@ class MuLawExpanding(object):
             x = ((x_mu) / mu) * 2 - 1.
             x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
         elif isinstance(x_mu, torch.Tensor):
-            if not x_mu.is_floating_point():
+            if not x_mu.dtype.is_floating_point:
                 x_mu = x_mu.to(torch.float)
             mu = torch.tensor(mu, dtype=x_mu.dtype)
             x = ((x_mu) / mu) * 2 - 1.
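For completeness, a quick round-trip check built only from the numpy expressions visible in these two hunks; the quantization level count is chosen arbitrarily:

    import numpy as np

    mu = 256 - 1                     # quantization_channels - 1
    x = np.linspace(-1., 1., 11)

    # MuLawEncoding, numpy branch
    x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)

    # MuLawExpanding, numpy branch
    x_rec = (x_mu / mu) * 2 - 1.
    x_rec = np.sign(x_rec) * (np.exp(np.abs(x_rec) * np.log1p(mu)) - 1.) / mu

    print(np.abs(x - x_rec).max())   # small quantization error, below 0.01 here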