Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
78be73b7
Commit
78be73b7
authored
Jan 15, 2018
by
David Pollack
Committed by
Soumith Chintala
Apr 25, 2018
Browse files
mel spectrograms in pytorch (no longer req librosa)
parent
c844ac63
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
219 additions
and
7 deletions
+219
-7
test/test_transforms.py
test/test_transforms.py
+12
-4
torchaudio/transforms.py
torchaudio/transforms.py
+207
-3
No files found.
test/test_transforms.py
View file @
78be73b7
...
@@ -10,8 +10,9 @@ class Tester(unittest.TestCase):
...
@@ -10,8 +10,9 @@ class Tester(unittest.TestCase):
sr
=
16000
sr
=
16000
freq
=
440
freq
=
440
volume
=
0
.3
volume
=
.
3
sig
=
(
torch
.
cos
(
2
*
np
.
pi
*
torch
.
arange
(
0
,
4
*
sr
)
*
freq
/
sr
)).
float
()
sig
=
(
torch
.
cos
(
2
*
np
.
pi
*
torch
.
arange
(
0
,
4
*
sr
)
*
freq
/
sr
)).
float
()
# sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
sig
.
unsqueeze_
(
1
)
sig
.
unsqueeze_
(
1
)
sig
=
(
sig
*
volume
*
2
**
31
).
long
()
sig
=
(
sig
*
volume
*
2
**
31
).
long
()
...
@@ -86,11 +87,11 @@ class Tester(unittest.TestCase):
...
@@ -86,11 +87,11 @@ class Tester(unittest.TestCase):
audio
=
self
.
sig
.
clone
()
audio
=
self
.
sig
.
clone
()
audio
=
transforms
.
Scale
()(
audio
)
audio
=
transforms
.
Scale
()(
audio
)
self
.
assertTrue
(
len
(
audio
.
size
()
)
==
2
)
self
.
assertTrue
(
audio
.
dim
(
)
==
2
)
result
=
transforms
.
MEL
()(
audio
)
result
=
transforms
.
MEL
()(
audio
)
self
.
assertTrue
(
len
(
result
.
size
()
)
==
3
)
self
.
assertTrue
(
result
.
dim
(
)
==
3
)
result
=
transforms
.
BLC2CBL
()(
result
)
result
=
transforms
.
BLC2CBL
()(
result
)
self
.
assertTrue
(
len
(
result
.
size
()
)
==
3
)
self
.
assertTrue
(
result
.
dim
(
)
==
3
)
repr_test
=
transforms
.
MEL
()
repr_test
=
transforms
.
MEL
()
repr_test
.
__repr__
()
repr_test
.
__repr__
()
...
@@ -146,6 +147,13 @@ class Tester(unittest.TestCase):
...
@@ -146,6 +147,13 @@ class Tester(unittest.TestCase):
repr_test
=
transforms
.
MuLawExpanding
(
quantization_channels
)
repr_test
=
transforms
.
MuLawExpanding
(
quantization_channels
)
repr_test
.
__repr__
()
repr_test
.
__repr__
()
def
test_mel2
(
self
):
audio_orig
=
self
.
sig
.
clone
()
# (16000, 1)
audio_scaled
=
transforms
.
Scale
()(
audio_orig
)
# (16000, 1)
audio_scaled
=
transforms
.
LC2CL
()(
audio_scaled
)
# (1, 16000)
spectrogram_torch
=
transforms
.
MEL2
()(
audio_scaled
)
# (1, 319, 40)
self
.
assertTrue
(
spectrogram_torch
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_torch
.
max
()
<=
0.
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
torchaudio/transforms.py
View file @
78be73b7
from
__future__
import
division
,
print_function
from
__future__
import
division
,
print_function
import
torch
import
torch
from
torch.autograd
import
Variable
import
numpy
as
np
import
numpy
as
np
try
:
try
:
import
librosa
import
librosa
...
@@ -7,6 +8,24 @@ except ImportError:
...
@@ -7,6 +8,24 @@ except ImportError:
librosa
=
None
librosa
=
None
def
_check_is_variable
(
tensor
):
if
isinstance
(
tensor
,
torch
.
Tensor
):
is_variable
=
False
tensor
=
Variable
(
tensor
,
requires_grad
=
False
)
elif
isinstance
(
tensor
,
Variable
):
is_variable
=
True
else
:
raise
TypeError
(
"tensor should be a Variable or Tensor, but is {}"
.
format
(
type
(
tensor
)))
return
tensor
,
is_variable
def
_tlog10
(
x
):
"""Pytorch Log10
"""
return
torch
.
log
(
x
)
/
torch
.
log
(
x
.
new
([
10
]))
class
Compose
(
object
):
class
Compose
(
object
):
"""Composes several transforms together.
"""Composes several transforms together.
...
@@ -137,10 +156,10 @@ class LC2CL(object):
...
@@ -137,10 +156,10 @@ class LC2CL(object):
"""
"""
Args:
Args:
tensor (Tensor): Tensor of
spectrogram
with shape (
Bx
LxC)
tensor (Tensor): Tensor of
audio signal
with shape (LxC)
Returns:
Returns:
tensor (Tensor): Tensor of
spectrogram
with shape (Cx
Bx
L)
tensor (Tensor): Tensor of
audio signal
with shape (CxL)
"""
"""
...
@@ -150,6 +169,190 @@ class LC2CL(object):
...
@@ -150,6 +169,190 @@ class LC2CL(object):
return
self
.
__class__
.
__name__
+
'()'
return
self
.
__class__
.
__name__
+
'()'
class SPECTROGRAM(object):
    """Create a spectrogram from a raw audio signal

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size, often called the fft size as well
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): number of fft bins. default: ws // 2 + 1
        pad (int): two sided padding of signal
        window (torch windowing function): default: torch.hann_window
        wkwargs (dict, optional): arguments for window function
    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
                 pad=0, window=torch.hann_window, wkwargs=None):
        if isinstance(window, Variable):
            # caller supplied a ready-made window (e.g. MEL2 passes its own)
            self.window = window
        else:
            self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
            # FIX: was Variable(..., volatile=True); `volatile` is deprecated and
            # inconsistent with MEL2, which uses requires_grad=False for the
            # same "no autograd" intent.
            self.window = Variable(self.window, requires_grad=False)
        self.sr = sr
        self.ws = ws
        self.hop = hop if hop is not None else ws // 2
        self.n_fft = n_fft  # number of fft bins
        self.pad = pad
        self.wkwargs = wkwargs

    def __call__(self, sig):
        """
        Args:
            sig (Tensor or Variable): Tensor of audio of size (c, n)

        Returns:
            spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels
                is unchanged, hops is the number of hops, and n_fft is the
                number of fourier bins, which should be the window size divided
                by 2 plus 1.
        """
        sig, is_variable = _check_is_variable(sig)
        assert sig.dim() == 2

        # NOTE(review): positional torch.stft call matches the 0.4-era
        # signature (signal, frame_length, hop, fft_size, normalized, window,
        # pad_end) -- confirm against the installed torch version.
        spec_f = torch.stft(sig, self.ws, self.hop, self.n_fft,
                            True, self.window, self.pad)  # (c, l, n_fft, 2)
        # normalize by the window energy so amplitude is window-independent
        spec_f /= self.window.pow(2).sum().sqrt()
        spec_f = spec_f.pow(2).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
        return spec_f if is_variable else spec_f.data
class F2M(object):
    """This turns a normal STFT into a MEL Frequency STFT, using a conversion
       matrix.  This uses triangular filter banks.

    Args:
        n_mels (int): number of MEL bins
        sr (int): sample rate of audio signal
        f_max (float, optional): maximum frequency. default: sr // 2
        f_min (float): minimum frequency. default: 0
    """
    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
        self.n_mels = n_mels
        self.sr = sr
        # default to the Nyquist frequency when no maximum is given
        self.f_max = f_max if f_max is not None else sr // 2
        self.f_min = f_min

    def __call__(self, spec_f):
        """Apply the mel filter bank to a power spectrogram.

        Args:
            spec_f (Tensor or Variable): power spectrogram of shape (c, l, n_fft)

        Returns:
            Tensor or Variable of shape (c, l, n_mels); a Variable is returned
            iff the input was a Variable.
        """
        spec_f, is_variable = _check_is_variable(spec_f)
        # number of one-sided frequency bins in the incoming spectrogram
        n_fft = spec_f.size(2)

        # HTK-style Hz -> mel conversion: m = 2595 * log10(1 + f / 700)
        m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
        m_max = 2595 * np.log10(1. + (self.f_max / 700))

        # n_mels + 2 points: each filter needs a left edge, center, right edge
        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
        # inverse conversion mel -> Hz for the filter edge frequencies
        f_pts = (700 * (10**(m_pts / 2595) - 1))

        # map edge frequencies to FFT bin indices; (n_fft - 1) * 2 recovers the
        # original window size (assumes n_fft == ws // 2 + 1 -- TODO confirm)
        bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()

        # build the (n_fft, n_mels) triangular filter-bank matrix
        fb = torch.zeros(n_fft, self.n_mels)
        for m in range(1, self.n_mels + 1):
            f_m_minus = bins[m - 1]  # left edge of filter m
            f_m = bins[m]            # center of filter m
            f_m_plus = bins[m + 1]   # right edge of filter m

            # guards skip degenerate (zero-width) slopes, which would divide by zero
            if f_m_minus != f_m:
                # rising slope: 0 at the left edge up to 1 at the center
                fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
            if f_m != f_m_plus:
                # falling slope: 1 at the center down to 0 at the right edge
                fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)

        fb = Variable(fb)
        spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
        return spec_m if is_variable else spec_m.data
class SPEC2DB(object):
    """Turns a spectrogram from the power/amplitude scale to the decibel scale.

    Args:
        stype (str): scale of input spectrogram ("power" or "magnitude").  The
            power being the elementwise square of the magnitude. default: "power"
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is -80.  default: None (no cut-off)
    """
    def __init__(self, stype="power", top_db=None):
        self.stype = stype
        # FIX: the previous `-top_db if top_db > 0 else top_db` evaluated
        # `None > 0` for the default argument, a TypeError on Python 3.
        # None means "no cut-off"; numeric values are stored as negative dB.
        if top_db is not None and top_db > 0:
            top_db = -top_db
        self.top_db = top_db
        # power spectra use 10 * log10, magnitude spectra 20 * log10
        self.multiplier = 10. if stype == "power" else 20.

    def __call__(self, spec):
        """Convert a spectrogram to decibels relative to its own maximum.

        Args:
            spec (Tensor or Variable): power or magnitude spectrogram

        Returns:
            Tensor or Variable of the same shape, in dB (<= 0), floored at
            top_db when one was given; a Variable iff the input was one.
        """
        spec, is_variable = _check_is_variable(spec)
        spec_db = self.multiplier * _tlog10(spec / spec.max())  # power -> dB
        if self.top_db is not None:
            spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
        return spec_db if is_variable else spec_db.data
class MEL2(object):
    """MEL spectrograms computed entirely with PyTorch's stft, avoiding the
    speed penalty of the librosa-based MEL transform.

    Sources:
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size, often called the fft size as well
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): number of fft bins. default: ws // 2 + 1
        pad (int): two sided padding of signal
        n_mels (int): number of MEL bins
        window (torch windowing function): default: torch.hann_window
        wkwargs (dict, optional): arguments for window function

    Example:
        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
        >>> sig = transforms.LC2CL()(sig)  # (n, c) -> (c, n)
        >>> spec_mel = transforms.MEL2(sr)(sig)  # (c, l, m)
    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, pad=0,
                 n_mels=40, window=torch.hann_window, wkwargs=None):
        win = window(ws) if wkwargs is None else window(ws, **wkwargs)
        self.window = Variable(win, requires_grad=False)
        self.sr = sr
        self.ws = ws
        self.hop = ws // 2 if hop is None else hop
        self.n_fft = n_fft  # number of fourier bins (ws // 2 + 1 by default)
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.wkwargs = wkwargs
        self.top_db = -80.
        self.f_max = None
        self.f_min = 0.

    def __call__(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (channels [c], samples [n])

        Returns:
            spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
                is unchanged, hops is the number of hops, and n_mels is the
                number of mel bins.
        """
        sig, is_variable = _check_is_variable(sig)
        # pipeline: power spectrogram -> mel filter bank -> decibel scale
        pipeline = Compose([
            SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
                        self.pad, self.window),
            F2M(self.n_mels, self.sr, self.f_max, self.f_min),
            SPEC2DB("power", self.top_db),
        ])
        spec_mel_db = pipeline(sig)
        return spec_mel_db if is_variable else spec_mel_db.data
class
MEL
(
object
):
class
MEL
(
object
):
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
...
@@ -164,7 +367,7 @@ class MEL(object):
...
@@ -164,7 +367,7 @@ class MEL(object):
"""
"""
Args:
Args:
tensor (Tensor): Tensor of audio of size (samples x channels)
tensor (Tensor): Tensor of audio of size (samples
[n]
x channels
[c]
)
Returns:
Returns:
tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
...
@@ -172,6 +375,7 @@ class MEL(object):
...
@@ -172,6 +375,7 @@ class MEL(object):
is unchanged.
is unchanged.
"""
"""
if
librosa
is
None
:
if
librosa
is
None
:
print
(
"librosa not installed, cannot create spectrograms"
)
print
(
"librosa not installed, cannot create spectrograms"
)
return
tensor
return
tensor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment