Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
a8d6a41b
Commit
a8d6a41b
authored
Jan 04, 2019
by
David Pollack
Committed by
Soumith Chintala
Jan 04, 2019
Browse files
optimization to MEL2 and fixes to filter bank conversion function
parent
3bd4db86
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
44 deletions
+64
-44
test/test_transforms.py
test/test_transforms.py
+30
-13
torchaudio/transforms.py
torchaudio/transforms.py
+34
-31
No files found.
test/test_transforms.py
View file @
a8d6a41b
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
torch
import
torch
import
torchaudio
import
torchaudio
import
torchaudio.transforms
as
transforms
import
torchaudio.transforms
as
transforms
...
@@ -8,13 +9,17 @@ import unittest
...
@@ -8,13 +9,17 @@ import unittest
class
Tester
(
unittest
.
TestCase
):
class
Tester
(
unittest
.
TestCase
):
# create a sinewave signal for testing
sr
=
16000
sr
=
16000
freq
=
440
freq
=
440
volume
=
.
3
volume
=
.
3
sig
=
(
torch
.
cos
(
2
*
np
.
pi
*
torch
.
arange
(
0
,
4
*
sr
).
float
()
*
freq
/
sr
))
sig
=
(
torch
.
cos
(
2
*
np
.
pi
*
torch
.
arange
(
0
,
4
*
sr
).
float
()
*
freq
/
sr
))
# sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
sig
.
unsqueeze_
(
1
)
# (64000, 1)
sig
.
unsqueeze_
(
1
)
sig
=
(
sig
*
volume
*
2
**
31
).
long
()
sig
=
(
sig
*
volume
*
2
**
31
).
long
()
# file for stereo stft test
test_dirpath
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
test_filepath
=
os
.
path
.
join
(
test_dirpath
,
"assets"
,
"steam-train-whistle-daniel_simon.mp3"
)
def
test_scale
(
self
):
def
test_scale
(
self
):
...
@@ -29,7 +34,7 @@ class Tester(unittest.TestCase):
...
@@ -29,7 +34,7 @@ class Tester(unittest.TestCase):
result
.
min
()
>=
-
1.
and
result
.
max
()
<=
1.
)
result
.
min
()
>=
-
1.
and
result
.
max
()
<=
1.
)
repr_test
=
transforms
.
Scale
()
repr_test
=
transforms
.
Scale
()
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
()
)
def
test_pad_trim
(
self
):
def
test_pad_trim
(
self
):
...
@@ -52,7 +57,7 @@ class Tester(unittest.TestCase):
...
@@ -52,7 +57,7 @@ class Tester(unittest.TestCase):
self
.
assertEqual
(
result
.
size
(
0
),
length_new
)
self
.
assertEqual
(
result
.
size
(
0
),
length_new
)
repr_test
=
transforms
.
PadTrim
(
max_len
=
length_new
,
channels_first
=
False
)
repr_test
=
transforms
.
PadTrim
(
max_len
=
length_new
,
channels_first
=
False
)
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
()
)
def
test_downmix_mono
(
self
):
def
test_downmix_mono
(
self
):
...
@@ -70,7 +75,7 @@ class Tester(unittest.TestCase):
...
@@ -70,7 +75,7 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
result
.
size
(
1
)
==
1
)
self
.
assertTrue
(
result
.
size
(
1
)
==
1
)
repr_test
=
transforms
.
DownmixMono
(
channels_first
=
False
)
repr_test
=
transforms
.
DownmixMono
(
channels_first
=
False
)
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
()
)
def
test_lc2cl
(
self
):
def
test_lc2cl
(
self
):
...
@@ -79,7 +84,7 @@ class Tester(unittest.TestCase):
...
@@ -79,7 +84,7 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
result
.
size
()[::
-
1
]
==
audio
.
size
())
self
.
assertTrue
(
result
.
size
()[::
-
1
]
==
audio
.
size
())
repr_test
=
transforms
.
LC2CL
()
repr_test
=
transforms
.
LC2CL
()
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
()
)
def
test_mel
(
self
):
def
test_mel
(
self
):
...
@@ -92,9 +97,10 @@ class Tester(unittest.TestCase):
...
@@ -92,9 +97,10 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
result
.
dim
()
==
3
)
self
.
assertTrue
(
result
.
dim
()
==
3
)
repr_test
=
transforms
.
MEL
()
repr_test
=
transforms
.
MEL
()
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
())
repr_test
=
transforms
.
BLC2CBL
()
repr_test
=
transforms
.
BLC2CBL
()
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
()
)
def
test_compose
(
self
):
def
test_compose
(
self
):
...
@@ -113,7 +119,7 @@ class Tester(unittest.TestCase):
...
@@ -113,7 +119,7 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
result
.
size
(
0
)
==
length_new
)
self
.
assertTrue
(
result
.
size
(
0
)
==
length_new
)
repr_test
=
transforms
.
Compose
(
tset
)
repr_test
=
transforms
.
Compose
(
tset
)
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
()
)
def
test_mu_law_companding
(
self
):
def
test_mu_law_companding
(
self
):
...
@@ -141,17 +147,28 @@ class Tester(unittest.TestCase):
...
@@ -141,17 +147,28 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
sig_exp
.
min
()
>=
-
1.
and
sig_exp
.
max
()
<=
1.
)
self
.
assertTrue
(
sig_exp
.
min
()
>=
-
1.
and
sig_exp
.
max
()
<=
1.
)
repr_test
=
transforms
.
MuLawEncoding
(
quantization_channels
)
repr_test
=
transforms
.
MuLawEncoding
(
quantization_channels
)
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
()
)
repr_test
=
transforms
.
MuLawExpanding
(
quantization_channels
)
repr_test
=
transforms
.
MuLawExpanding
(
quantization_channels
)
repr_test
.
__repr__
()
self
.
assertTrue
(
repr_test
.
__repr__
()
)
def
test_mel2
(
self
):
def
test_mel2
(
self
):
audio_orig
=
self
.
sig
.
clone
()
# (16000, 1)
audio_orig
=
self
.
sig
.
clone
()
# (16000, 1)
audio_scaled
=
transforms
.
Scale
()(
audio_orig
)
# (16000, 1)
audio_scaled
=
transforms
.
Scale
()(
audio_orig
)
# (16000, 1)
audio_scaled
=
transforms
.
LC2CL
()(
audio_scaled
)
# (1, 16000)
audio_scaled
=
transforms
.
LC2CL
()(
audio_scaled
)
# (1, 16000)
spectrogram_torch
=
transforms
.
MEL2
(
window_fn
=
torch
.
hamming_window
,
pad
=
10
)(
audio_scaled
)
# (1, 319, 40)
mel_transform
=
transforms
.
MEL2
(
window
=
torch
.
hamming_window
,
pad
=
10
)
spectrogram_torch
=
mel_transform
(
audio_scaled
)
# (1, 319, 40)
self
.
assertTrue
(
spectrogram_torch
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_torch
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_torch
.
max
()
<=
0.
)
self
.
assertTrue
(
spectrogram_torch
.
le
(
0.
).
all
())
self
.
assertTrue
(
spectrogram_torch
.
ge
(
mel_transform
.
top_db
).
all
())
self
.
assertEqual
(
spectrogram_torch
.
size
(
-
1
),
mel_transform
.
n_mels
)
# load stereo file
x_stereo
,
sr_stereo
=
torchaudio
.
load
(
self
.
test_filepath
)
spectrogram_stereo
=
mel_transform
(
x_stereo
)
self
.
assertTrue
(
spectrogram_stereo
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_stereo
.
size
(
0
)
==
2
)
self
.
assertTrue
(
spectrogram_stereo
.
le
(
0.
).
all
())
self
.
assertTrue
(
spectrogram_stereo
.
ge
(
mel_transform
.
top_db
).
all
())
self
.
assertEqual
(
spectrogram_stereo
.
size
(
-
1
),
mel_transform
.
n_mels
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
torchaudio/transforms.py
View file @
a8d6a41b
...
@@ -60,7 +60,7 @@ class Scale(object):
...
@@ -60,7 +60,7 @@ class Scale(object):
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
"""
"""
if
not
tensor
.
is_floating_point
()
:
if
not
tensor
.
dtype
.
is_floating_point
:
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
=
tensor
.
to
(
torch
.
float32
)
return
tensor
/
self
.
factor
return
tensor
/
self
.
factor
...
@@ -125,7 +125,7 @@ class DownmixMono(object):
...
@@ -125,7 +125,7 @@ class DownmixMono(object):
self
.
ch_dim
=
int
(
not
channels_first
)
self
.
ch_dim
=
int
(
not
channels_first
)
def
__call__
(
self
,
tensor
):
def
__call__
(
self
,
tensor
):
if
not
tensor
.
is_floating_point
()
:
if
not
tensor
.
dtype
.
is_floating_point
:
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
=
torch
.
mean
(
tensor
,
self
.
ch_dim
,
True
)
tensor
=
torch
.
mean
(
tensor
,
self
.
ch_dim
,
True
)
...
@@ -169,8 +169,8 @@ class SPECTROGRAM(object):
...
@@ -169,8 +169,8 @@ class SPECTROGRAM(object):
"""
"""
def
__init__
(
self
,
sr
=
16000
,
ws
=
400
,
hop
=
None
,
n_fft
=
None
,
def
__init__
(
self
,
sr
=
16000
,
ws
=
400
,
hop
=
None
,
n_fft
=
None
,
pad
=
0
,
window
_fn
=
torch
.
hann_window
,
wkwargs
=
None
):
pad
=
0
,
window
=
torch
.
hann_window
,
wkwargs
=
None
):
self
.
window
=
window
_fn
(
ws
)
if
wkwargs
is
None
else
window
_fn
(
ws
,
**
wkwargs
)
self
.
window
=
window
(
ws
)
if
wkwargs
is
None
else
window
(
ws
,
**
wkwargs
)
self
.
sr
=
sr
self
.
sr
=
sr
self
.
ws
=
ws
self
.
ws
=
ws
self
.
hop
=
hop
if
hop
is
not
None
else
ws
//
2
self
.
hop
=
hop
if
hop
is
not
None
else
ws
//
2
...
@@ -197,7 +197,6 @@ class SPECTROGRAM(object):
...
@@ -197,7 +197,6 @@ class SPECTROGRAM(object):
if
self
.
pad
>
0
:
if
self
.
pad
>
0
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
sig
=
torch
.
nn
.
functional
.
pad
(
sig
,
(
self
.
pad
,
self
.
pad
),
"constant"
)
sig
=
torch
.
nn
.
functional
.
pad
(
sig
,
(
self
.
pad
,
self
.
pad
),
"constant"
)
spec_f
=
torch
.
stft
(
sig
,
self
.
n_fft
,
self
.
hop
,
self
.
ws
,
spec_f
=
torch
.
stft
(
sig
,
self
.
n_fft
,
self
.
hop
,
self
.
ws
,
self
.
window
,
center
=
False
,
self
.
window
,
center
=
False
,
normalized
=
True
,
onesided
=
True
).
transpose
(
1
,
2
)
normalized
=
True
,
onesided
=
True
).
transpose
(
1
,
2
)
...
@@ -215,16 +214,28 @@ class F2M(object):
...
@@ -215,16 +214,28 @@ class F2M(object):
sr (int): sample rate of audio signal
sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: sr // 2
f_max (float, optional): maximum frequency. default: sr // 2
f_min (float): minimum frequency. default: 0
f_min (float): minimum frequency. default: 0
n_fft (int, optional): number of filter banks from stft. Calculated from first input
if `None` is given.
"""
"""
def
__init__
(
self
,
n_mels
=
40
,
sr
=
16000
,
f_max
=
None
,
f_min
=
0.
):
def
__init__
(
self
,
n_mels
=
40
,
sr
=
16000
,
f_max
=
None
,
f_min
=
0.
,
n_fft
=
None
):
self
.
n_mels
=
n_mels
self
.
n_mels
=
n_mels
self
.
sr
=
sr
self
.
sr
=
sr
self
.
f_max
=
f_max
if
f_max
is
not
None
else
sr
//
2
self
.
f_max
=
f_max
if
f_max
is
not
None
else
sr
//
2
self
.
f_min
=
f_min
self
.
f_min
=
f_min
self
.
fb
=
self
.
_create_fb_matrix
(
n_fft
)
if
n_fft
is
not
None
else
n_fft
def
__call__
(
self
,
spec_f
):
def
__call__
(
self
,
spec_f
):
if
self
.
fb
is
None
:
self
.
fb
=
self
.
_create_fb_matrix
(
spec_f
.
size
(
2
))
spec_m
=
torch
.
matmul
(
spec_f
,
self
.
fb
)
# (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return
spec_m
def
_create_fb_matrix
(
self
,
n_fft
):
""" Create a frequency bin conversion matrix.
n_fft
=
spec_f
.
size
(
2
)
Args:
n_fft (int): number of filter banks from spectrogram
"""
m_min
=
0.
if
self
.
f_min
==
0
else
2595
*
np
.
log10
(
1.
+
(
self
.
f_min
/
700
))
m_min
=
0.
if
self
.
f_min
==
0
else
2595
*
np
.
log10
(
1.
+
(
self
.
f_min
/
700
))
m_max
=
2595
*
np
.
log10
(
1.
+
(
self
.
f_max
/
700
))
m_max
=
2595
*
np
.
log10
(
1.
+
(
self
.
f_max
/
700
))
...
@@ -234,19 +245,12 @@ class F2M(object):
...
@@ -234,19 +245,12 @@ class F2M(object):
bins
=
torch
.
floor
(((
n_fft
-
1
)
*
2
)
*
f_pts
/
self
.
sr
).
long
()
bins
=
torch
.
floor
(((
n_fft
-
1
)
*
2
)
*
f_pts
/
self
.
sr
).
long
()
fb
=
torch
.
zeros
(
n_fft
,
self
.
n_mels
)
fb
=
torch
.
zeros
(
n_fft
,
self
.
n_mels
,
dtype
=
torch
.
float
)
for
m
in
range
(
1
,
self
.
n_mels
+
1
):
for
m
in
range
(
1
,
self
.
n_mels
+
1
):
f_m_minus
=
bins
[
m
-
1
].
item
()
f_m_minus
=
bins
[
m
-
1
].
item
()
f_m
=
bins
[
m
].
item
()
f_m_plus
=
bins
[
m
+
1
].
item
()
f_m_plus
=
bins
[
m
+
1
].
item
()
fb
[
f_m_minus
:
f_m_plus
,
m
-
1
]
=
torch
.
bartlett_window
(
f_m_plus
-
f_m_minus
)
if
f_m_minus
!=
f_m
:
return
fb
fb
[
f_m_minus
:
f_m
,
m
-
1
]
=
(
torch
.
arange
(
f_m_minus
,
f_m
)
-
f_m_minus
)
/
(
f_m
-
f_m_minus
)
if
f_m
!=
f_m_plus
:
fb
[
f_m
:
f_m_plus
,
m
-
1
]
=
(
f_m_plus
-
torch
.
arange
(
f_m
,
f_m_plus
))
/
(
f_m_plus
-
f_m
)
spec_m
=
torch
.
matmul
(
spec_f
,
fb
)
# (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return
spec_m
class
SPEC2DB
(
object
):
class
SPEC2DB
(
object
):
...
@@ -267,7 +271,7 @@ class SPEC2DB(object):
...
@@ -267,7 +271,7 @@ class SPEC2DB(object):
spec_db
=
self
.
multiplier
*
torch
.
log10
(
spec
/
spec
.
max
())
# power -> dB
spec_db
=
self
.
multiplier
*
torch
.
log10
(
spec
/
spec
.
max
())
# power -> dB
if
self
.
top_db
is
not
None
:
if
self
.
top_db
is
not
None
:
spec_db
=
torch
.
max
(
spec_db
,
spec_db
.
new
([
self
.
top_db
]
))
spec_db
=
torch
.
max
(
spec_db
,
torch
.
tensor
(
self
.
top_db
,
dtype
=
spec_db
.
dtype
))
return
spec_db
return
spec_db
...
@@ -296,8 +300,8 @@ class MEL2(object):
...
@@ -296,8 +300,8 @@ class MEL2(object):
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
"""
"""
def
__init__
(
self
,
sr
=
16000
,
ws
=
400
,
hop
=
None
,
n_fft
=
None
,
def
__init__
(
self
,
sr
=
16000
,
ws
=
400
,
hop
=
None
,
n_fft
=
None
,
pad
=
0
,
n_mels
=
40
,
window
_fn
=
torch
.
hann_window
,
wkwargs
=
None
):
pad
=
0
,
n_mels
=
40
,
window
=
torch
.
hann_window
,
wkwargs
=
None
):
self
.
window
_fn
=
window
_fn
self
.
window
=
window
self
.
sr
=
sr
self
.
sr
=
sr
self
.
ws
=
ws
self
.
ws
=
ws
self
.
hop
=
hop
if
hop
is
not
None
else
ws
//
2
self
.
hop
=
hop
if
hop
is
not
None
else
ws
//
2
...
@@ -308,6 +312,13 @@ class MEL2(object):
...
@@ -308,6 +312,13 @@ class MEL2(object):
self
.
top_db
=
-
80.
self
.
top_db
=
-
80.
self
.
f_max
=
None
self
.
f_max
=
None
self
.
f_min
=
0.
self
.
f_min
=
0.
self
.
spec
=
SPECTROGRAM
(
self
.
sr
,
self
.
ws
,
self
.
hop
,
self
.
n_fft
,
self
.
pad
,
self
.
window
,
self
.
wkwargs
)
self
.
fm
=
F2M
(
self
.
n_mels
,
self
.
sr
,
self
.
f_max
,
self
.
f_min
,
self
.
n_fft
)
self
.
s2db
=
SPEC2DB
(
"power"
,
self
.
top_db
)
self
.
transforms
=
Compose
([
self
.
spec
,
self
.
fm
,
self
.
s2db
,
])
def
__call__
(
self
,
sig
):
def
__call__
(
self
,
sig
):
"""
"""
...
@@ -320,15 +331,7 @@ class MEL2(object):
...
@@ -320,15 +331,7 @@ class MEL2(object):
number of mel bins.
number of mel bins.
"""
"""
spec_mel_db
=
self
.
transforms
(
sig
)
transforms
=
Compose
([
SPECTROGRAM
(
self
.
sr
,
self
.
ws
,
self
.
hop
,
self
.
n_fft
,
self
.
pad
,
self
.
window_fn
,
self
.
wkwargs
),
F2M
(
self
.
n_mels
,
self
.
sr
,
self
.
f_max
,
self
.
f_min
),
SPEC2DB
(
"power"
,
self
.
top_db
),
])
spec_mel_db
=
transforms
(
sig
)
return
spec_mel_db
return
spec_mel_db
...
@@ -426,7 +429,7 @@ class MuLawEncoding(object):
...
@@ -426,7 +429,7 @@ class MuLawEncoding(object):
x_mu
=
np
.
sign
(
x
)
*
np
.
log1p
(
mu
*
np
.
abs
(
x
))
/
np
.
log1p
(
mu
)
x_mu
=
np
.
sign
(
x
)
*
np
.
log1p
(
mu
*
np
.
abs
(
x
))
/
np
.
log1p
(
mu
)
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
astype
(
int
)
x_mu
=
((
x_mu
+
1
)
/
2
*
mu
+
0.5
).
astype
(
int
)
elif
isinstance
(
x
,
torch
.
Tensor
):
elif
isinstance
(
x
,
torch
.
Tensor
):
if
not
x
.
is_floating_point
()
:
if
not
x
.
dtype
.
is_floating_point
:
x
=
x
.
to
(
torch
.
float
)
x
=
x
.
to
(
torch
.
float
)
mu
=
torch
.
tensor
(
mu
,
dtype
=
x
.
dtype
)
mu
=
torch
.
tensor
(
mu
,
dtype
=
x
.
dtype
)
x_mu
=
torch
.
sign
(
x
)
*
torch
.
log1p
(
mu
*
x_mu
=
torch
.
sign
(
x
)
*
torch
.
log1p
(
mu
*
...
@@ -468,7 +471,7 @@ class MuLawExpanding(object):
...
@@ -468,7 +471,7 @@ class MuLawExpanding(object):
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
np
.
sign
(
x
)
*
(
np
.
exp
(
np
.
abs
(
x
)
*
np
.
log1p
(
mu
))
-
1.
)
/
mu
x
=
np
.
sign
(
x
)
*
(
np
.
exp
(
np
.
abs
(
x
)
*
np
.
log1p
(
mu
))
-
1.
)
/
mu
elif
isinstance
(
x_mu
,
torch
.
Tensor
):
elif
isinstance
(
x_mu
,
torch
.
Tensor
):
if
not
x_mu
.
is_floating_point
()
:
if
not
x_mu
.
dtype
.
is_floating_point
:
x_mu
=
x_mu
.
to
(
torch
.
float
)
x_mu
=
x_mu
.
to
(
torch
.
float
)
mu
=
torch
.
tensor
(
mu
,
dtype
=
x_mu
.
dtype
)
mu
=
torch
.
tensor
(
mu
,
dtype
=
x_mu
.
dtype
)
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment