OpenDAS / Torchaudio, commit b29a4639

[BC] Standardization of Transforms/Functionals (#152)

Authored Jul 24, 2019 by jamarshon; committed by cpuhrsch on Jul 24, 2019
Parent commit: 2271a7ae
Showing 5 changed files with 365 additions and 839 deletions (+365 / -839):

  test/test_functional.py     +2    -52
  test/test_jit.py            +12   -92
  test/test_transforms.py     +74   -142
  torchaudio/functional.py    +115  -240
  torchaudio/transforms.py    +162  -313
test/test_functional.py
@@ -2,6 +2,8 @@ import math
 import torch
 import torchaudio
+import torchaudio.functional as F
+import pytest
 import unittest
 import test.common_utils
...
@@ -11,10 +13,6 @@ if IMPORT_LIBROSA:
     import numpy as np
     import librosa
-    import pytest
-    import torchaudio.functional as F
-
-xfail = pytest.mark.xfail


 class TestFunctional(unittest.TestCase):
     data_sizes = [(2, 20), (3, 15), (4, 10)]
...
@@ -197,54 +195,6 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad):
     return (signal_len + 2 * pad - fft_len + hop_length) // hop_length


-@pytest.mark.parametrize('fft_length', [512])
-@pytest.mark.parametrize('hop_length', [256])
-@pytest.mark.parametrize('waveform', [
-    (torch.randn(1, 100000)),
-    (torch.randn(1, 2, 100000)),
-    pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)),
-])
-@pytest.mark.parametrize('pad_mode', [
-    # 'constant',
-    'reflect',
-])
-@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
-def test_stft(waveform, fft_length, hop_length, pad_mode):
-    """
-    Test STFT for multi-channel signals.
-
-    Padding: Value in having padding outside of torch.stft?
-    """
-    pad = fft_length // 2
-    window = torch.hann_window(fft_length)
-    complex_spec = F.stft(waveform, fft_length=fft_length, hop_length=hop_length,
-                          window=window, pad_mode=pad_mode)
-    mag_spec, phase_spec = F.magphase(complex_spec)
-
-    # == Test shape
-    expected_size = list(waveform.size()[:-1])
-    expected_size += [fft_length // 2 + 1,
-                      _num_stft_bins(waveform.size(-1), fft_length, hop_length, pad), 2]
-
-    assert complex_spec.dim() == waveform.dim() + 2
-    assert complex_spec.size() == torch.Size(expected_size)
-
-    # == Test values
-    fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)
-    # note that librosa *automatically* pad with fft_length // 2.
-    expected_complex_spec = np.apply_along_axis(librosa.stft, -1, waveform.numpy(), **fft_config)
-    expected_mag_spec, _ = librosa.magphase(expected_complex_spec)
-
-    # Convert torch to np.complex
-    complex_spec = complex_spec.numpy()
-    complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1]
-
-    assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5)
-    assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5)
-
-
 @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
 @pytest.mark.parametrize('complex_specgrams', [
     torch.randn(1, 2, 1025, 400, 2),
...
test/test_jit.py
@@ -30,40 +30,18 @@ class Test_JIT(unittest.TestCase):
         self.assertTrue(torch.allclose(jit_out, py_out))

-    def test_torchscript_scale(self):
-        @torch.jit.script
-        def jit_method(tensor, factor):
-            # type: (Tensor, int) -> Tensor
-            return F.scale(tensor, factor)
-
-        tensor = torch.rand((10, 1))
-        factor = 2
-
-        jit_out = jit_method(tensor, factor)
-        py_out = F.scale(tensor, factor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_scale(self):
-        tensor = torch.rand((10, 1), device="cuda")
-
-        self._test_script_module(tensor, transforms.Scale)
-
     def test_torchscript_pad_trim(self):
         @torch.jit.script
-        def jit_method(tensor, ch_dim, max_len, len_dim, fill_value):
-            # type: (Tensor, int, int, int, float) -> Tensor
-            return F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
+        def jit_method(tensor, max_len, fill_value):
+            # type: (Tensor, int, float) -> Tensor
+            return F.pad_trim(tensor, max_len, fill_value)

-        tensor = torch.rand((10, 1))
-        ch_dim = 1
+        tensor = torch.rand((1, 10))
         max_len = 5
-        len_dim = 0
         fill_value = 3.

-        jit_out = jit_method(tensor, ch_dim, max_len, len_dim, fill_value)
-        py_out = F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
+        jit_out = jit_method(tensor, max_len, fill_value)
+        py_out = F.pad_trim(tensor, max_len, fill_value)

         self.assertTrue(torch.allclose(jit_out, py_out))
...
@@ -74,45 +52,6 @@ class Test_JIT(unittest.TestCase):
         self._test_script_module(tensor, transforms.PadTrim, max_len)

-    def test_torchscript_downmix_mono(self):
-        @torch.jit.script
-        def jit_method(tensor, ch_dim):
-            # type: (Tensor, int) -> Tensor
-            return F.downmix_mono(tensor, ch_dim)
-
-        tensor = torch.rand((10, 1))
-        ch_dim = 1
-
-        jit_out = jit_method(tensor, ch_dim)
-        py_out = F.downmix_mono(tensor, ch_dim)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_downmix_mono(self):
-        tensor = torch.rand((1, 10), device="cuda")
-
-        self._test_script_module(tensor, transforms.DownmixMono)
-
-    def test_torchscript_LC2CL(self):
-        @torch.jit.script
-        def jit_method(tensor):
-            # type: (Tensor) -> Tensor
-            return F.LC2CL(tensor)
-
-        tensor = torch.rand((10, 1))
-
-        jit_out = jit_method(tensor)
-        py_out = F.LC2CL(tensor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_LC2CL(self):
-        tensor = torch.rand((10, 1), device="cuda")
-
-        self._test_script_module(tensor, transforms.LC2CL)
-
     def test_torchscript_spectrogram(self):
         @torch.jit.script
         def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
...
@@ -167,7 +106,7 @@ class Test_JIT(unittest.TestCase):
             # type: (Tensor, float, float, float, Optional[float]) -> Tensor
             return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db)

-        spec = torch.rand((10, 1))
+        spec = torch.rand((6, 201))
         multiplier = 10.
         amin = 1e-10
         db_multiplier = 0.
...
@@ -180,7 +119,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_SpectrogramToDB(self):
-        spec = torch.rand((10, 1), device="cuda")
+        spec = torch.rand((6, 201), device="cuda")

         self._test_script_module(spec, transforms.SpectrogramToDB)
...
@@ -211,32 +150,13 @@ class Test_JIT(unittest.TestCase):
         self._test_script_module(tensor, transforms.MelSpectrogram)

-    def test_torchscript_BLC2CBL(self):
-        @torch.jit.script
-        def jit_method(tensor):
-            # type: (Tensor) -> Tensor
-            return F.BLC2CBL(tensor)
-
-        tensor = torch.rand((10, 1000, 1))
-
-        jit_out = jit_method(tensor)
-        py_out = F.BLC2CBL(tensor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_BLC2CBL(self):
-        tensor = torch.rand((10, 1000, 1), device="cuda")
-
-        self._test_script_module(tensor, transforms.BLC2CBL)
-
     def test_torchscript_mu_law_encoding(self):
         @torch.jit.script
         def jit_method(tensor, qc):
             # type: (Tensor, int) -> Tensor
             return F.mu_law_encoding(tensor, qc)

-        tensor = torch.rand((10, 1))
+        tensor = torch.rand((1, 10))
         qc = 256

         jit_out = jit_method(tensor, qc)
...
@@ -246,7 +166,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_MuLawEncoding(self):
-        tensor = torch.rand((10, 1), device="cuda")
+        tensor = torch.rand((1, 10), device="cuda")

         self._test_script_module(tensor, transforms.MuLawEncoding)
...
@@ -256,7 +176,7 @@ class Test_JIT(unittest.TestCase):
             # type: (Tensor, int) -> Tensor
             return F.mu_law_expanding(tensor, qc)

-        tensor = torch.rand((10, 1))
+        tensor = torch.rand((1, 10))
         qc = 256

         jit_out = jit_method(tensor, qc)
...
@@ -266,7 +186,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_MuLawExpanding(self):
-        tensor = torch.rand((10, 1), device="cuda")
+        tensor = torch.rand((1, 10), device="cuda")

         self._test_script_module(tensor, transforms.MuLawExpanding)
...
test/test_transforms.py
@@ -19,191 +19,123 @@ if IMPORT_SCIPY:
 class Tester(unittest.TestCase):

     # create a sinewave signal for testing
-    sr = 16000
+    sample_rate = 16000
     freq = 440
     volume = .3
-    sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
-    sig.unsqueeze_(1)  # (64000, 1)
-    sig = (sig * volume * 2 ** 31).long()
+    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
+    waveform.unsqueeze_(0)  # (1, 64000)
+    waveform = (waveform * volume * 2 ** 31).long()

     # file for stereo stft test
     test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
-    test_filepath = os.path.join(test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3")
+    test_filepath = os.path.join(test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3')
-    def test_scale(self):
-        audio_orig = self.sig.clone()
-        result = transforms.Scale()(audio_orig)
-        self.assertTrue(result.min() >= -1. and result.max() <= 1.)
-
-        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
-        result = transforms.Scale(factor=maxminmax)(audio_orig)
-        self.assertTrue((result.min() == -1. or result.max() == 1.) and
-                        result.min() >= -1. and result.max() <= 1.)
-
-        repr_test = transforms.Scale()
-        self.assertTrue(repr_test.__repr__())
+    def scale(self, waveform, factor=float(2 ** 31)):
+        # scales a waveform by a factor
+        if not waveform.is_floating_point():
+            waveform = waveform.to(torch.get_default_dtype())
+        return waveform / factor
     def test_pad_trim(self):
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
+        waveform = self.waveform.clone()
+        length_orig = waveform.size(1)
         length_new = int(length_orig * 1.2)

-        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
-        self.assertEqual(result.size(0), length_new)
-
-        result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
+        result = transforms.PadTrim(max_len=length_new)(waveform)
+        self.assertEqual(result.size(1), length_new)

-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
         length_new = int(length_orig * 0.8)

-        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
-        self.assertEqual(result.size(0), length_new)
-
-        repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
-        self.assertTrue(repr_test.__repr__())
-    def test_downmix_mono(self):
-        audio_L = self.sig.clone()
-        audio_R = self.sig.clone()
-        R_idx = int(audio_R.size(0) * 0.1)
-        audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))
-        audio_Stereo = torch.cat((audio_L, audio_R), dim=1)
-        self.assertTrue(audio_Stereo.size(1) == 2)
-
-        result = transforms.DownmixMono(channels_first=False)(audio_Stereo)
-        self.assertTrue(result.size(1) == 1)
-
-        repr_test = transforms.DownmixMono(channels_first=False)
-        self.assertTrue(repr_test.__repr__())
-
-    def test_lc2cl(self):
-        audio = self.sig.clone()
-        result = transforms.LC2CL()(audio)
-        self.assertTrue(result.size()[::-1] == audio.size())
-
-        repr_test = transforms.LC2CL()
-        self.assertTrue(repr_test.__repr__())
-    def test_compose(self):
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
-        length_new = int(length_orig * 1.2)
-        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
-
-        tset = (transforms.Scale(factor=maxminmax),
-                transforms.PadTrim(max_len=length_new, channels_first=False))
-        result = transforms.Compose(tset)(audio_orig)
-
-        self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)
-        self.assertTrue(result.size(0) == length_new)
-
-        repr_test = transforms.Compose(tset)
-        self.assertTrue(repr_test.__repr__())
+        result = transforms.PadTrim(max_len=length_new)(waveform)
+        self.assertEqual(result.size(1), length_new)
     def test_mu_law_companding(self):

         quantization_channels = 256

-        sig = self.sig.clone()
-        sig = sig / torch.abs(sig).max()
-        self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)
-
-        sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
-        self.assertTrue(sig_mu.min() >= 0. and sig.max() <= quantization_channels)
+        waveform = self.waveform.clone()
+        waveform /= torch.abs(waveform).max()
+        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

-        sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
-        self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
+        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
+        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

-        repr_test = transforms.MuLawEncoding(quantization_channels)
-        self.assertTrue(repr_test.__repr__())
-        repr_test = transforms.MuLawExpanding(quantization_channels)
-        self.assertTrue(repr_test.__repr__())
+        waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu)
+        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
     def test_mel2(self):
         top_db = 80.
-        s2db = transforms.SpectrogramToDB("power", top_db)
+        s2db = transforms.SpectrogramToDB('power', top_db)

-        audio_orig = self.sig.clone()  # (16000, 1)
-        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
-        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
+        waveform = self.waveform.clone()  # (1, 16000)
+        waveform_scaled = self.scale(waveform)  # (1, 16000)
         mel_transform = transforms.MelSpectrogram()
         # check defaults
-        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 319, 40)
+        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
         self.assertTrue(spectrogram_torch.dim() == 3)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
+        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
         # check correctness of filterbank conversion matrix
-        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
-        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
+        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
+        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
         # check options
-        kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
+        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
+                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
         mel_transform2 = transforms.MelSpectrogram(**kwargs)
-        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, 506, 50)
+        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
         self.assertTrue(spectrogram2_torch.dim() == 3)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
-        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
-        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
+        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
+        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
+        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
         # check on multi-channel audio
-        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
-        spectrogram_stereo = s2db(mel_transform(x_stereo))
+        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
+        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
         self.assertTrue(spectrogram_stereo.dim() == 3)
         self.assertTrue(spectrogram_stereo.size(0) == 2)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
+        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
         # check filterbank matrix creation
-        fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
+        fb_matrix_transform = transforms.MelScale(n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
         self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
         self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
         self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
     def test_mfcc(self):
-        audio_orig = self.sig.clone()
-        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
-        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
+        audio_orig = self.waveform.clone()
+        audio_scaled = self.scale(audio_orig)  # (1, 16000)

         sample_rate = 16000
         n_mfcc = 40
         n_mels = 128
-        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
-                                                    n_mfcc=n_mfcc, norm='ortho')
+        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
+                                                    n_mfcc=n_mfcc, norm='ortho')
         # check defaults
-        torch_mfcc = mfcc_transform(audio_scaled)
+        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
         self.assertTrue(torch_mfcc.dim() == 3)
-        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
-        self.assertTrue(torch_mfcc.shape[1] == 321)
+        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
+        self.assertTrue(torch_mfcc.shape[2] == 321)
         # check melkwargs are passed through
-        melkwargs = {'ws': 200}
-        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
-                                                     n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
-        torch_mfcc2 = mfcc_transform2(audio_scaled)
-        self.assertTrue(torch_mfcc2.shape[1] == 641)
+        melkwargs = {'win_length': 200}
+        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
+                                                     n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
+        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
+        self.assertTrue(torch_mfcc2.shape[2] == 641)
         # check norms work correctly
-        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
-                                                              n_mfcc=n_mfcc, norm=None)
-        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)
+        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
+                                                              n_mfcc=n_mfcc, norm=None)
+        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

         norm_check = torch_mfcc.clone()
-        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
-        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2
+        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
+        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

         self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
...
@@ -212,45 +144,45 @@ class Tester(unittest.TestCase):
         def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
             input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
             sound, sample_rate = torchaudio.load(input_path)
-            sound_librosa = sound.cpu().numpy().squeeze().T  # squeeze batch and channel first
+            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

             # test core spectrogram
-            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2)
+            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
             out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=2)
-            out_torch = spect_transform(sound).squeeze().cpu().t()
+            out_torch = spect_transform(sound).squeeze().cpu()
             self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

             # test mel spectrogram
-            melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, hop=hop_length, n_mels=n_mels, n_fft=n_fft)
+            melspect_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, window_fn=torch.hann_window, hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
             librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
             librosa_mel_tensor = torch.from_numpy(librosa_mel)
-            torch_mel = melspect_transform(sound).squeeze().cpu().t()
+            torch_mel = melspect_transform(sound).squeeze().cpu()
             self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

             # test s2db
-            db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
-            db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
+            db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.)
+            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
             db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
             self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

-            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
+            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
             db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
             db_librosa_tensor = torch.from_numpy(db_librosa)
             self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

             # test MFCC
-            melkwargs = {'hop': hop_length, 'n_fft': n_fft}
-            mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
+            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
+            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
...
@@ -271,7 +203,7 @@ class Tester(unittest.TestCase):
             librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
             librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
-            torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()
+            torch_mfcc = mfcc_transform(sound).squeeze().cpu()
             self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))
...
@@ -308,27 +240,27 @@ class Tester(unittest.TestCase):
     def test_resample_size(self):
         input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
-        sound, sample_rate = torchaudio.load(input_path)
+        waveform, sample_rate = torchaudio.load(input_path)

         upsample_rate = sample_rate * 2
         downsample_rate = sample_rate // 2
         invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
-        self.assertRaises(ValueError, invalid_resample, sound)
+        self.assertRaises(ValueError, invalid_resample, waveform)

         upsample_resample = torchaudio.transforms.Resample(
             sample_rate, upsample_rate, resampling_method='sinc_interpolation')
-        up_sampled = upsample_resample(sound)
+        up_sampled = upsample_resample(waveform)

         # we expect the upsampled signal to have twice as many samples
-        self.assertTrue(up_sampled.size(-1) == sound.size(-1) * 2)
+        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

         downsample_resample = torchaudio.transforms.Resample(
             sample_rate, downsample_rate, resampling_method='sinc_interpolation')
-        down_sampled = downsample_resample(sound)
+        down_sampled = downsample_resample(waveform)

         # we expect the downsampled signal to have half as many samples
-        self.assertTrue(down_sampled.size(-1) == sound.size(-1) // 2)
+        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)


 if __name__ == '__main__':
     unittest.main()
torchaudio/functional.py
@@ -3,109 +3,48 @@ import torch

 __all__ = [
     'scale',
     'pad_trim',
     'downmix_mono',
     'LC2CL',
     'istft',
     'spectrogram',
     'create_fb_matrix',
     'spectrogram_to_DB',
     'create_dct',
     'BLC2CBL',
     'mu_law_encoding',
-    'mu_law_expanding'
+    'mu_law_expanding',
+    'complex_norm',
+    'angle',
+    'magphase',
+    'phase_vocoder',
 ]


 @torch.jit.script
 def scale(tensor, factor):
     # type: (Tensor, int) -> Tensor
     r"""Scale audio tensor from a 16-bit integer (represented as a
     :class:`torch.FloatTensor`) to a floating point number between -1.0 and 1.0.
     Note the 16-bit number is called the "bit depth" or "precision", not to be
     confused with "bit rate".

     Args:
         tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
         factor (int): Maximum value of input tensor

     Returns:
         torch.Tensor: Scaled by the scale factor
     """
     if not tensor.is_floating_point():
         tensor = tensor.to(torch.float32)

     return tensor / factor


 @torch.jit.script
-def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
-    # type: (Tensor, int, int, int, float) -> Tensor
-    r"""Pad/trim a 2D tensor (signal or labels).
+def pad_trim(waveform, max_len, fill_value):
+    # type: (Tensor, int, float) -> Tensor
+    r"""Pad/trim a 2D tensor

     Args:
-        tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
-        ch_dim (int): Dimension of channel (not size)
-        max_len (int): Length to which the tensor will be padded
-        len_dim (int): Dimension of length (not size)
+        waveform (torch.Tensor): Tensor of audio of size (c, n)
+        max_len (int): Length to which the waveform will be padded
         fill_value (float): Value to fill in

     Returns:
         torch.Tensor: Padded/trimmed tensor
     """
-    if max_len > tensor.size(len_dim):
-        # array of [padding_left, padding_right, padding_top, padding_bottom]
-        # so pad similar to append (aka only right/bottom) and do not pad
-        # the length dimension. assumes equal sizes of padding.
-        padding = [max_len - tensor.size(len_dim)
-                   if (i % 2 == 1) and (i // 2 != len_dim)
-                   else 0
-                   for i in [0, 1, 2, 3]]
-        # TODO add "with torch.no_grad():" back when JIT supports it
-        tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
-    elif max_len < tensor.size(len_dim):
-        tensor = tensor.narrow(len_dim, 0, max_len)
-    return tensor
+    n = waveform.size(1)
+    if max_len > n:
+        # TODO add "with torch.no_grad():" back when JIT supports it
+        waveform = torch.nn.functional.pad(waveform, (0, max_len - n), 'constant', fill_value)
+    else:
+        waveform = waveform[:, :max_len]
+    return waveform


 @torch.jit.script
 def downmix_mono(tensor, ch_dim):
     # type: (Tensor, int) -> Tensor
     r"""Downmix any stereo signals to mono.

     Args:
         tensor (torch.Tensor): Tensor of audio of size (c, n) or (n, c)
         ch_dim (int): Dimension of channel (not size)

     Returns:
         torch.Tensor: Mono signal
     """
     if not tensor.is_floating_point():
         tensor = tensor.to(torch.float32)

     tensor = torch.mean(tensor, ch_dim, True)
     return tensor


 @torch.jit.script
 def LC2CL(tensor):
     # type: (Tensor) -> Tensor
     r"""Permute a 2D tensor from samples (n, c) to (c, n).

     Args:
         tensor (torch.Tensor): Tensor of audio signal with shape (n, c)

     Returns:
         torch.Tensor: Tensor of audio signal with shape (c, n)
     """
     return tensor.transpose(0, 1).contiguous()


 # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
 @torch.jit.ignore
-def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
+def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
     # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
-    return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
+    return torch.stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)


 def istft(stft_matrix,          # type: Tensor
...
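A minimal usage sketch of the new three-argument `pad_trim` shown in the hunk above, which always treats its input as `(c, n)` (not part of the commit; sizes and values are illustrative):

import torch
import torchaudio.functional as F

waveform = torch.zeros(1, 10)            # (c, n)
padded = F.pad_trim(waveform, 15, 0.)    # right-pad with fill_value -> (1, 15)
trimmed = F.pad_trim(waveform, 5, 0.)    # slice down to max_len     -> (1, 5)
print(padded.shape, trimmed.shape)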
@@ -149,8 +88,8 @@ def istft(stft_matrix,  # type: Tensor
         IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

     Args:
-        stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each
-            column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or (
-            fft_size, n_frames, 2)
+        stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
+            column is a window. it has a shape of either (channel, fft_size, n_frames, 2) or (
+            fft_size, n_frames, 2)
         n_fft (int): Size of Fourier transform
         hop_length (Optional[int]): The distance between neighboring sliding window frames.
...
@@ -168,20 +107,20 @@ def istft(stft_matrix,  # type: Tensor
     Returns:
         torch.Tensor: Least squares estimation of the original signal of size
-        (batch, signal_length) or (signal_length)
+        (channel, signal_length) or (signal_length)
     """
     stft_matrix_dim = stft_matrix.dim()
     assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim))

     if stft_matrix_dim == 3:
-        # add a batch dimension
+        # add a channel dimension
         stft_matrix = stft_matrix.unsqueeze(0)

     device = stft_matrix.device
     fft_size = stft_matrix.size(1)
     assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), (
         'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' +
         'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))

     # use stft defaults for Optionals
     if win_length is None:
...
@@ -206,16 +145,16 @@ def istft(stft_matrix,  # type: Tensor
     assert window.size(0) == n_fft
     # win_length and n_fft are synonymous from here on

-    stft_matrix = stft_matrix.transpose(1, 2)  # size (batch, n_frames, fft_size, 2)
+    stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frames, fft_size, 2)
-    stft_matrix = torch.irfft(stft_matrix, 1, normalized,
-                              onesided, signal_sizes=(n_fft,))  # size (batch, n_frames, n_fft)
+    stft_matrix = torch.irfft(stft_matrix, 1, normalized,
+                              onesided, signal_sizes=(n_fft,))  # size (channel, n_frames, n_fft)

     assert stft_matrix.size(2) == n_fft
     n_frames = stft_matrix.size(1)

-    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (batch, n_frames, n_fft)
-    # each column of a batch is a frame which needs to be overlap added at the right place
-    ytmp = ytmp.transpose(1, 2)  # size (batch, n_fft, n_frames)
+    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (channel, n_frames, n_fft)
+    # each column of a channel is a frame which needs to be overlap added at the right place
+    ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frames)

     eye = torch.eye(n_fft, requires_grad=False, device=device).unsqueeze(1)  # size (n_fft, 1, n_fft)
...
@@ -223,7 +162,7 @@ def istft(stft_matrix,  # type: Tensor
     # this does overlap add where the frames of ytmp are added such that the i'th frame of
     # ytmp is added starting at i*hop_length in the output
-    y = torch.nn.functional.conv_transpose1d(
-        ytmp, eye, stride=hop_length, padding=0)  # size (batch, 1, expected_signal_len)
+    y = torch.nn.functional.conv_transpose1d(
+        ytmp, eye, stride=hop_length, padding=0)  # size (channel, 1, expected_signal_len)

     # do the same for the window function
     window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)  # size (1, n_fft, n_frames)
...
@@ -246,67 +185,70 @@ def istft(stft_matrix,  # type: Tensor
     window_envelop_lowest = window_envelop.abs().min()
     assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest))

-    y = (y / window_envelop).squeeze(1)  # size (batch, expected_signal_len)
-    if stft_matrix_dim == 3:  # remove the batch dimension
+    y = (y / window_envelop).squeeze(1)  # size (channel, expected_signal_len)
+    if stft_matrix_dim == 3:  # remove the channel dimension
         y = y.squeeze(0)
     return y


 @torch.jit.script
-def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
+def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized):
     # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
     r"""Create a spectrogram from a raw audio signal.

     Args:
-        sig (torch.Tensor): Tensor of audio of size (c, n)
+        waveform (torch.Tensor): Tensor of audio of size (c, n)
         pad (int): Two sided padding of signal
-        window (torch.Tensor): Window_tensor
+        window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
         n_fft (int): Size of fft
-        hop (int): Length of hop between STFT windows
-        ws (int): Window size
-        power (int) : Exponent for the magnitude spectrogram,
+        hop_length (int): Length of hop between STFT windows
+        win_length (int): Window size
+        power (int): Exponent for the magnitude spectrogram,
             (must be > 0) e.g., 1 for energy, 2 for power, etc.
-        normalize (bool) : Whether to normalize by magnitude after stft
+        normalized (bool): Whether to normalize by magnitude after stft

     Returns:
-        torch.Tensor: Channels x hops x n_fft (c, l, f), where channels
-        is unchanged, hops is the number of hops, and n_fft is the
-        number of fourier bins, which should be the window size divided
-        by 2 plus 1.
+        torch.Tensor: Channels x frequency x time (c, f, t), where channels
+        is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
+        fourier bins, and time is the number of window hops (n_frames).
     """
-    assert sig.dim() == 2
+    assert waveform.dim() == 2

     if pad > 0:
         # TODO add "with torch.no_grad():" back when JIT supports it
-        sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
+        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

     # default values are consistent with librosa.core.spectrum._spectrogram
-    spec_f = _stft(sig, n_fft, hop, ws, window, True, 'reflect', False, True).transpose(1, 2)
+    spec_f = _stft(waveform, n_fft, hop_length, win_length, window, True, 'reflect', False, True)

-    if normalize:
+    if normalized:
         spec_f /= window.pow(2).sum().sqrt()
-    spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
+    spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor

     return spec_f


 @torch.jit.script
-def create_fb_matrix(n_stft, f_min, f_max, n_mels):
+def create_fb_matrix(n_freqs, f_min, f_max, n_mels):
     # type: (int, float, float, int) -> Tensor
     r""" Create a frequency bin conversion matrix.

     Args:
-        n_stft (int): Number of filter banks from spectrogram
+        n_freqs (int): Number of frequencies to highlight/apply
         f_min (float): Minimum frequency
         f_max (float): Maximum frequency
-        n_mels (int): Number of mel bins
+        n_mels (int): Number of mel filterbanks

     Returns:
-        torch.Tensor: Triangular filter banks (fb matrix)
+        torch.Tensor: Triangular filter banks (fb matrix) of size (`n_freqs`, `n_mels`)
+        meaning number of frequencies to highlight/apply to x the number of filterbanks.
+        Each column is a filterbank so that assuming there is a matrix A of
+        size (..., `n_freqs`), the applied result would be
+        `A * create_fb_matrix(A.size(-1), ...)`.
     """
-    # get stft freq bins
-    stft_freqs = torch.linspace(f_min, f_max, n_stft)
+    # freq bins
+    freqs = torch.linspace(f_min, f_max, n_freqs)
     # calculate mel freq bins
     # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
     m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.))
...
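Per the new `create_fb_matrix` docstring, the returned filterbank has size (`n_freqs`, `n_mels`) and is right-applied to data whose last dimension is `n_freqs`. A hedged sketch of that usage on a spectrogram in the new `(c, f, t)` layout, mirroring what `MelScale.forward` does later in this diff (sizes illustrative, not part of the commit):

import torch
import torchaudio.functional as F

specgram = torch.rand(1, 201, 100)                         # (c, f, t), f = n_fft // 2 + 1
fb = F.create_fb_matrix(specgram.size(1), 0., 8000., 128)  # (n_freqs, n_mels)

# (c, t, f) @ (f, n_mels) -> (c, t, n_mels), then back to (c, n_mels, t)
mel_specgram = torch.matmul(specgram.transpose(1, 2), fb).transpose(1, 2)
print(mel_specgram.shape)  # torch.Size([1, 128, 100])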
@@ -316,17 +258,17 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
     f_pts = 700. * (10 ** (m_pts / 2595.) - 1.)
     # calculate the difference between each mel point and each stft freq point in hertz
     f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
-    slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_stft, n_mels + 2)
+    slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
     # create overlapping triangles
-    z = torch.zeros(1)
-    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_stft, n_mels)
-    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_stft, n_mels)
-    fb = torch.max(z, torch.min(down_slopes, up_slopes))
+    zero = torch.zeros(1)
+    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
+    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
+    fb = torch.max(zero, torch.min(down_slopes, up_slopes))
     return fb


 @torch.jit.script
-def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
+def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None):
     # type: (Tensor, float, float, float, Optional[float]) -> Tensor
     r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
...
@@ -335,72 +277,57 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
     a full clip.

     Args:
-        spec (torch.Tensor): Normal STFT
+        specgram (torch.Tensor): Normal STFT of size (c, f, t)
         multiplier (float): Use 10. for power and 20. for amplitude
-        amin (float): Number to clamp spec
+        amin (float): Number to clamp specgram
         db_multiplier (float): Log10(max(reference value and amin))
-        top_db (Optional[float]): Minimum negative cut-off in decibels.  A reasonable number
+        top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
             is 80.

     Returns:
-        torch.Tensor: Spectrogram in DB
+        torch.Tensor: Spectrogram in DB of size (c, f, t)
     """
-    spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin))
-    spec_db -= multiplier * db_multiplier
+    specgram_db = multiplier * torch.log10(torch.clamp(specgram, min=amin))
+    specgram_db -= multiplier * db_multiplier

     if top_db is not None:
-        new_spec_db_max = torch.tensor(float(spec_db.max()) - top_db, dtype=spec_db.dtype, device=spec_db.device)
-        spec_db = torch.max(spec_db, new_spec_db_max)
+        new_spec_db_max = torch.tensor(float(specgram_db.max()) - top_db,
+                                       dtype=specgram_db.dtype, device=specgram_db.device)
+        specgram_db = torch.max(specgram_db, new_spec_db_max)

-    return spec_db
+    return specgram_db


 @torch.jit.script
 def create_dct(n_mfcc, n_mels, norm):
     # type: (int, int, Optional[str]) -> Tensor
-    r"""Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
+    r"""Creates a DCT transformation matrix with shape (`n_mels`, `n_mfcc`),
     normalized depending on norm.

     Args:
-        n_mfcc (int) : Number of mfc coefficients to retain
-        n_mels (int): Number of MEL bins
-        norm (Optional[str]) : Norm to use (either 'ortho' or None)
+        n_mfcc (int): Number of mfc coefficients to retain
+        n_mels (int): Number of mel filterbanks
+        norm (Optional[str]): Norm to use (either 'ortho' or None)

     Returns:
-        torch.Tensor: The transformation matrix, to be right-multiplied to row-wise data.
+        torch.Tensor: The transformation matrix, to be right-multiplied to
+        row-wise data of size (`n_mels`, `n_mfcc`).
     """
-    outdim = n_mfcc
-    dim = n_mels
     # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
-    n = torch.arange(dim)
-    k = torch.arange(outdim)[:, None]
-    dct = torch.cos(math.pi / float(dim) * (n + 0.5) * k)
+    n = torch.arange(float(n_mels))
+    k = torch.arange(float(n_mfcc)).unsqueeze(1)
+    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
     if norm is None:
         dct *= 2.0
     else:
         assert norm == 'ortho'
         dct[0] *= 1.0 / math.sqrt(2.0)
-        dct *= math.sqrt(2.0 / float(dim))
+        dct *= math.sqrt(2.0 / float(n_mels))
     return dct.t()


 @torch.jit.script
 def BLC2CBL(tensor):
     # type: (Tensor) -> Tensor
     r"""Permute a 3D tensor from Bands x Sample length x Channels to Channels x
     Bands x Samples length.

     Args:
         tensor (torch.Tensor): Tensor of spectrogram with shape (b, l, c)

     Returns:
         torch.Tensor: Tensor of spectrogram with shape (c, b, l)
     """
     return tensor.permute(2, 0, 1).contiguous()


 @torch.jit.script
-def mu_law_encoding(x, qc):
+def mu_law_encoding(x, quantization_channels):
     # type: (Tensor, int) -> Tensor
     r"""Encode signal based on mu-law companding.  For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...
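A short usage sketch of the renamed `spectrogram_to_DB`, using the same constants the updated tests use (10. for a power spectrogram, amin 1e-10, db_multiplier 0., top_db 80.); not part of the commit, sizes illustrative:

import torch
import torchaudio.functional as F

specgram = torch.rand(1, 201, 100)  # (c, f, t) power spectrogram
specgram_db = F.spectrogram_to_DB(specgram, 10., 1e-10, 0., 80.)

# With top_db=80., values are clamped to at most 80 dB below the maximum.
print(float(specgram_db.max() - specgram_db.min()) <= 80.)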
@@ -410,13 +337,12 @@ def mu_law_encoding(x, qc):
     Args:
         x (torch.Tensor): Input tensor
-        qc (int): Number of channels (i.e. quantization channels)
+        quantization_channels (int): Number of channels

     Returns:
         torch.Tensor: Input after mu-law companding
     """
-    assert isinstance(x, torch.Tensor), 'mu_law_encoding expects a Tensor'
-    mu = qc - 1.
+    mu = quantization_channels - 1.
     if not x.is_floating_point():
         x = x.to(torch.float)
     mu = torch.tensor(mu, dtype=x.dtype)
...
@@ -427,7 +353,7 @@ def mu_law_encoding(x, qc):
 @torch.jit.script
-def mu_law_expanding(x_mu, qc):
+def mu_law_expanding(x_mu, quantization_channels):
     # type: (Tensor, int) -> Tensor
     r"""Decode mu-law encoded signal.  For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...
@@ -437,13 +363,12 @@ def mu_law_expanding(x_mu, qc):
     Args:
         x_mu (torch.Tensor): Input tensor
-        qc (int): Number of channels (i.e. quantization channels)
+        quantization_channels (int): Number of channels

     Returns:
         torch.Tensor: Input after decoding
     """
-    assert isinstance(x_mu, torch.Tensor), 'mu_law_expanding expects a Tensor'
-    mu = qc - 1.
+    mu = quantization_channels - 1.
     if not x_mu.is_floating_point():
         x_mu = x_mu.to(torch.float)
     mu = torch.tensor(mu, dtype=x_mu.dtype)
...
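With `qc` renamed to `quantization_channels`, a round-trip sketch of the two functionals (not part of the commit; the companding is lossy, so the comparison is approximate):

import torch
import torchaudio.functional as F

quantization_channels = 256
waveform = torch.rand(1, 16000) * 2 - 1                       # floats in [-1., 1.]

encoded = F.mu_law_encoding(waveform, quantization_channels)  # quantized values in [0, quantization_channels]
decoded = F.mu_law_expanding(encoded, quantization_channels)  # back to [-1., 1.]
print(float((waveform - decoded).abs().max()))                # small quantization error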
@@ -452,71 +377,15 @@ def mu_law_expanding(x_mu, qc):
     return x


-def stft(waveforms, fft_length, hop_length=None, win_length=None, window=None,
-         center=True, pad_mode='reflect', normalized=False, onesided=True):
-    """Compute a short time Fourier transform of the input waveform(s).
-    It wraps `torch.stft` after reshaping the input audio to allow for `waveforms` that `.dim()` >= 3.
-    It follows most of the `torch.stft` default values, but for `window`, which defaults to hann window.
-
-    Args:
-        waveforms (torch.Tensor): Audio signal of size `(*, channel, time)`
-        fft_length (int): FFT size [sample].
-        hop_length (int): Hop size [sample] between STFT frames.
-            (Defaults to `fft_length // 4`, 75%-overlapping windows by `torch.stft`).
-        win_length (int): Size of STFT window. (Defaults to `fft_length` by `torch.stft`).
-        window (torch.Tensor): window function. (Defaults to Hann Window of size `win_length` *unlike* `torch.stft`).
-        center (bool): Whether to pad `waveforms` on both sides so that the `t`-th frame is centered
-            at time `t * hop_length`. (Defaults to `True` by `torch.stft`)
-        pad_mode (str): padding method (see `torch.nn.functional.pad`). (Defaults to `'reflect'` by `torch.stft`).
-        normalized (bool): Whether the results are normalized. (Defaults to `False` by `torch.stft`).
-        onesided (bool): Whether the half + 1 frequency bins are returned to removethe symmetric part of STFT
-            of real-valued signal. (Defaults to `True` by `torch.stft`).
-
-    Returns:
-        torch.Tensor: `(*, channel, num_freqs, time, complex=2)`
-
-    Example:
-        >>> waveforms = torch.randn(16, 2, 10000)  # (batch, channel, time)
-        >>> x = stft(waveforms, 2048, 512)
-        >>> x.shape
-        torch.Size([16, 2, 1025, 20])
-    """
-    leading_dims = waveforms.shape[:-1]
-
-    waveforms = waveforms.reshape(-1, waveforms.size(-1))
-
-    if window is None:
-        if win_length is None:
-            window = torch.hann_window(fft_length)
-        else:
-            window = torch.hann_window(win_length)
-
-    complex_specgrams = torch.stft(waveforms, n_fft=fft_length, hop_length=hop_length,
-                                   win_length=win_length, window=window, center=center,
-                                   pad_mode=pad_mode, normalized=normalized, onesided=onesided)
-
-    complex_specgrams = complex_specgrams.reshape(leading_dims + complex_specgrams.shape[1:])
-
-    return complex_specgrams
-
-
 def complex_norm(complex_tensor, power=1.0):
-    """Compute the norm of complex tensor input
+    r"""Compute the norm of complex tensor input.

     Args:
-        complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
-        power (float): Power of the norm. Defaults to `1.0`.
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+        power (float): Power of the norm. (Default: `1.0`)

     Returns:
-        Tensor: power of the normed input tensor, shape of `(*, )`
+        torch.Tensor: Power of the normed input tensor. Shape of `(*, )`
     """
     if power == 1.0:
         return torch.norm(complex_tensor, 2, -1)
...
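The complex helpers newly listed in `__all__` operate on tensors whose last dimension holds the real/imaginary pair. A small sketch consistent with the docstrings above (shapes illustrative, not part of the commit):

import torch
import torchaudio.functional as F

complex_specgram = torch.randn(2, 1025, 400, 2)            # (*, complex=2)

magnitude = F.complex_norm(complex_specgram, power=1.0)    # (*, )
phase = F.angle(complex_specgram)                          # (*, )
mag2, phase2 = F.magphase(complex_specgram, power=1.0)     # same two results in one call
print(torch.allclose(magnitude, mag2), torch.allclose(phase, phase2))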
@@ -524,16 +393,26 @@ def complex_norm(complex_tensor, power=1.0):
 def angle(complex_tensor):
-    """
-    Return angle of a complex tensor with shape (*, 2).
-    """
+    r"""Compute the angle of complex tensor input.
+
+    Args:
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+
+    Return:
+        torch.Tensor: Angle of a complex tensor. Shape of `(*, )`
+    """
     return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])


 def magphase(complex_tensor, power=1.):
-    """
-    Separate a complex-valued spectrogram with shape (*,2)
-    into its magnitude and phase.
-    """
+    r"""Separate a complex-valued spectrogram with shape (*,2) into its magnitude and phase.
+
+    Args:
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+        power (float): Power of the norm. (Default: `1.0`)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex_tensor
+    """
     mag = complex_norm(complex_tensor, power)
     phase = angle(complex_tensor)
...
@@ -541,20 +420,16 @@ def magphase(complex_tensor, power=1.):
 def phase_vocoder(complex_specgrams, rate, phase_advance):
-    """
-    Phase vocoder. Given a STFT tensor, speed up in time
-    without modifying pitch by a factor of `rate`.
+    r"""Given a STFT tensor, speed up in time without modifying pitch by a
+    factor of `rate`.

     Args:
-        complex_specgrams (Tensor): (*, channel, num_freqs, time, complex=2)
-        rate (float): Speed-up factor.
-        phase_advance (Tensor): Expected phase advance in each bin. (num_freqs, 1).
+        complex_specgrams (torch.Tensor): Size of (*, c, f, t, complex=2)
+        rate (float): Speed-up factor
+        phase_advance (torch.Tensor): Expected phase advance in each bin. Size of (f, 1)

     Returns:
-        complex_specgrams_stretch (Tensor): (*, channel, num_freqs, ceil(time/rate), complex=2).
+        complex_specgrams_stretch (torch.Tensor): Size of (*, c, f, ceil(t/rate), complex=2)

     Example:
         >>> num_freqs, hop_length = 1025, 512
...
torchaudio/transforms.py
@@ -7,314 +7,205 @@ from . import functional as F
 from .compliance import kaldi


+# TODO remove this class
 class Compose(object):
     """Composes several transforms together.

     Args:
         transforms (list of ``Transform`` objects): list of transforms to compose.

     Example:
         >>> transforms.Compose([
         >>>     transforms.Scale(),
         >>>     transforms.PadTrim(max_len=16000),
         >>> ])
     """

     def __init__(self, transforms):
         self.transforms = transforms

     def __call__(self, audio):
         for t in self.transforms:
             audio = t(audio)
         return audio

     def __repr__(self):
         format_string = self.__class__.__name__ + '('
         for t in self.transforms:
             format_string += '\n'
             format_string += '    {0}'.format(t)
         format_string += '\n)'
         return format_string
-class Scale(torch.jit.ScriptModule):
-    """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
-    to a floating point number between -1.0 and 1.0.  Note the 16-bit number is
-    called the "bit depth" or "precision", not to be confused with "bit rate".
-
-    Args:
-        factor (int): maximum value of input tensor. default: 16-bit depth
-    """
-    __constants__ = ['factor']
-
-    def __init__(self, factor=2 ** 31):
-        super(Scale, self).__init__()
-        self.factor = factor
-
-    @torch.jit.script_method
-    def forward(self, tensor):
-        """
-        Args:
-            tensor (Tensor): Tensor of audio of size (Samples x Channels)
-
-        Returns:
-            Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
-        """
-        return F.scale(tensor, self.factor)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
 class PadTrim(torch.jit.ScriptModule):
-    """Pad/Trim a 2d-Tensor (Signal or Labels)
+    r"""Pad/Trim a 2D tensor

     Args:
-        tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
-        max_len (int): Length to which the tensor will be padded
-        channels_first (bool): Pad for channels first tensors.  Default: `True`
+        max_len (int): Length to which the waveform will be padded
         fill_value (float): Value to fill in
     """
-    __constants__ = ['max_len', 'fill_value', 'len_dim', 'ch_dim']
+    __constants__ = ['max_len', 'fill_value']

-    def __init__(self, max_len, fill_value=0., channels_first=True):
+    def __init__(self, max_len, fill_value=0.):
         super(PadTrim, self).__init__()
         self.max_len = max_len
         self.fill_value = fill_value
-        self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)

     @torch.jit.script_method
-    def forward(self, tensor):
-        """
-        Returns:
-            Tensor: (c x n) or (n x c)
-        """
-        return F.pad_trim(tensor, self.ch_dim, self.max_len, self.len_dim, self.fill_value)
+    def forward(self, waveform):
+        r"""
+        Args:
+            waveform (torch.Tensor): Tensor of audio of size (c, n)
+
+        Returns:
+            Tensor: Tensor of size (c, `max_len`)
+        """
+        return F.pad_trim(waveform, self.max_len, self.fill_value)

     def __repr__(self):
         return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)


-class DownmixMono(torch.jit.ScriptModule):
-    """Downmix any stereo signals to mono.  Consider using a `SoxEffectsChain` with
-    the `channels` effect instead of this transformation.
-
-    Inputs:
-        tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
-        channels_first (bool): Downmix across channels dimension.  Default: `True`
-
-    Returns:
-        tensor (Tensor) (Samples x 1):
-    """
-    __constants__ = ['ch_dim']
-
-    def __init__(self, channels_first=None):
-        super(DownmixMono, self).__init__()
-        self.ch_dim = int(not channels_first)
-
-    @torch.jit.script_method
-    def forward(self, tensor):
-        return F.downmix_mono(tensor, self.ch_dim)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
-
-class LC2CL(torch.jit.ScriptModule):
-    """Permute a 2d tensor from samples (n x c) to (c x n)
-    """
-
-    def __init__(self):
-        super(LC2CL, self).__init__()
-
-    @torch.jit.script_method
-    def forward(self, tensor):
-        """
-        Args:
-            tensor (Tensor): Tensor of audio signal with shape (LxC)
-
-        Returns:
-            tensor (Tensor): Tensor of audio signal with shape (CxL)
-        """
-        return F.LC2CL(tensor)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
-
-def SPECTROGRAM(*args, **kwargs):
-    warn("SPECTROGRAM has been renamed to Spectrogram")
-    return Spectrogram(*args, **kwargs)
 class Spectrogram(torch.jit.ScriptModule):
-    """Create a spectrogram from a raw audio signal
+    r"""Create a spectrogram from a audio signal

     Args:
-        n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins
-        ws (int): window size. default: n_fft
-        hop (int, optional): length of hop between STFT windows. default: ws // 2
-        pad (int): two sided padding of signal
-        window (torch windowing function): default: torch.hann_window
-        power (int > 0 ) : Exponent for the magnitude spectrogram,
-            e.g., 1 for energy, 2 for power, etc.
-        normalize (bool) : whether to normalize by magnitude after stft
-        wkwargs (dict, optional): arguments for window function
+        n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
+        win_length (int): Window size. (Default: `n_fft`)
+        hop_length (int, optional): Length of hop between STFT windows. (Default: `win_length // 2`)
+        pad (int): Two sided padding of signal. (Default: 0)
+        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+            that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
+        power (int): Exponent for the magnitude spectrogram,
+            (must be > 0) e.g., 1 for energy, 2 for power, etc.
+        normalized (bool): Whether to normalize by magnitude after stft. (Default: `False`)
+        wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
     """
-    __constants__ = ['n_fft', 'ws', 'hop', 'pad', 'power', 'normalize']
+    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

-    def __init__(self, n_fft=400, ws=None, hop=None, pad=0,
-                 window=torch.hann_window, power=2, normalize=False, wkwargs=None):
+    def __init__(self, n_fft=400, win_length=None, hop_length=None, pad=0,
+                 window_fn=torch.hann_window, power=2, normalized=False, wkwargs=None):
         super(Spectrogram, self).__init__()
         self.n_fft = n_fft
         # number of fft bins. the returned STFT result will have n_fft // 2 + 1
         # number of frequecies due to onesided=True in torch.stft
-        self.ws = ws if ws is not None else n_fft
-        self.hop = hop if hop is not None else self.ws // 2
-        window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs)
+        self.win_length = win_length if win_length is not None else n_fft
+        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
         self.window = torch.jit.Attribute(window, torch.Tensor)
         self.pad = pad
         self.power = power
-        self.normalize = normalize
+        self.normalized = normalized

     @torch.jit.script_method
-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): Tensor of audio of size (c, n)
+            waveform (torch.Tensor): Tensor of audio of size (c, n)

         Returns:
-            spec_f (Tensor): channels x hops x n_fft (c, l, f), where channels
-            is unchanged, hops is the number of hops, and n_fft is the
-            number of fourier bins, which should be the window size divided
-            by 2 plus 1.
+            torch.Tensor: Channels x frequency x time (c, f, t), where channels
+            is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
+            fourier bins, and time is the number of window hops (n_frames).
         """
-        return F.spectrogram(sig, self.pad, self.window, self.n_fft, self.hop,
-                             self.ws, self.power, self.normalize)
+        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
+                             self.win_length, self.power, self.normalized)


-def F2M(*args, **kwargs):
-    warn("F2M has been renamed to MelScale")
-    return MelScale(*args, **kwargs)
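A hedged usage sketch of the transform with the standardized keyword names from this diff (`win_length`, `hop_length`, `window_fn`, `normalized`); the waveform and sizes are illustrative, not part of the commit:

import torch
import torchaudio.transforms as transforms

waveform = torch.randn(1, 16000)  # (c, n)
spec_transform = transforms.Spectrogram(n_fft=400, win_length=400, hop_length=200,
                                        window_fn=torch.hann_window, power=2, normalized=False)
specgram = spec_transform(waveform)

# New output layout is channels x frequency x time: (c, n_fft // 2 + 1, n_frames).
print(specgram.shape)  # torch.Size([1, 201, 81])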
class MelScale(torch.jit.ScriptModule):
-    """This turns a normal STFT into a mel frequency STFT, using a conversion
+    r"""This turns a normal STFT into a mel frequency STFT, using a conversion
     matrix. This uses triangular filter banks.

     User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

     Args:
-        n_mels (int): number of mel bins
-        sr (int): sample rate of audio signal
-        f_max (float, optional): maximum frequency. default: `sr` // 2
-        f_min (float): minimum frequency. default: 0
-        n_stft (int, optional): number of filter banks from stft. Calculated from first input
+        n_mels (int): Number of mel filterbanks. (Default: 128)
+        sample_rate (int): Sample rate of audio signal. (Default: 16000)
+        f_min (float): Minimum frequency. (Default: 0.)
+        f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`)
+        n_stft (int, optional): Number of bins in STFT. Calculated from first input
             if `None` is given. See `n_fft` in `Spectrogram`.
     """
-    __constants__ = ['n_mels', 'sr', 'f_min', 'f_max']
+    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

-    def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None):
+    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
         super(MelScale, self).__init__()
         self.n_mels = n_mels
-        self.sr = sr
-        self.f_max = f_max if f_max is not None else float(sr // 2)
+        self.sample_rate = sample_rate
+        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
         assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
         self.f_min = f_min
         fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
             n_stft, self.f_min, self.f_max, self.n_mels)
         self.fb = torch.jit.Attribute(fb, torch.Tensor)

     @torch.jit.script_method
-    def forward(self, spec_f):
+    def forward(self, specgram):
+        r"""
+        Args:
+            specgram (torch.Tensor): a spectrogram STFT of size (c, f, t)
+
+        Returns:
+            torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
+        """
         if self.fb.numel() == 0:
-            tmp_fb = F.create_fb_matrix(spec_f.size(2), self.f_min, self.f_max, self.n_mels)
+            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels)
             # Attributes cannot be reassigned outside __init__ so workaround
             self.fb.resize_(tmp_fb.size())
             self.fb.copy_(tmp_fb)
-        spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
-        return spec_m
+        # (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...)
+        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
+        return mel_specgram
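The reworked forward keeps the (c, f, t) layout and applies the filter bank along the frequency axis. A small self-contained sketch of the equivalent tensor algebra (random tensors, names purely illustrative):

import torch

c, f, t, n_mels = 1, 201, 100, 128
specgram = torch.randn(c, f, t).pow(2)   # power spectrogram, (c, f, t)
fb = torch.rand(f, n_mels)               # stand-in triangular filter bank, (f, n_mels)

# (c, f, t) -> (c, t, f) @ (f, n_mels) -> (c, t, n_mels) -> back to (c, n_mels, t)
mel_specgram = torch.matmul(specgram.transpose(1, 2), fb).transpose(1, 2)
print(mel_specgram.shape)                # torch.Size([1, 128, 100])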
class SpectrogramToDB(torch.jit.ScriptModule):
-    """Turns a spectrogram from the power/amplitude scale to the decibel scale.
+    r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.

     This output depends on the maximum value in the input spectrogram, and so
     may return different values for an audio clip split into snippets vs. a
     a full clip.

     Args:
-        stype (str): scale of input spectrogram ("power" or "magnitude"). The
-            power being the elementwise square of the magnitude. default: "power"
+        stype (str): scale of input spectrogram ('power' or 'magnitude'). The
+            power being the elementwise square of the magnitude. (Default: 'power')
         top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
             is 80.
     """
     __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

-    def __init__(self, stype="power", top_db=None):
+    def __init__(self, stype='power', top_db=None):
         super(SpectrogramToDB, self).__init__()
         self.stype = torch.jit.Attribute(stype, str)
         if top_db is not None and top_db < 0:
             raise ValueError('top_db must be positive value')
         self.top_db = torch.jit.Attribute(top_db, Optional[float])
-        self.multiplier = 10. if stype == "power" else 20.
+        self.multiplier = 10.0 if stype == 'power' else 20.0
         self.amin = 1e-10
-        self.ref_value = 1.
+        self.ref_value = 1.0
         self.db_multiplier = math.log10(max(self.amin, self.ref_value))

     @torch.jit.script_method
-    def forward(self, spec):
-        # numerically stable implementation from librosa
-        # https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
-        return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db)
+    def forward(self, specgram):
+        r"""Numerically stable implementation from Librosa
+        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
+
+        Args:
+            specgram (torch.Tensor): STFT of size (c, f, t)
+
+        Returns:
+            torch.Tensor: STFT after changing scale of size (c, f, t)
+        """
+        return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db)
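As a hedged sketch of the numerically stable scaling this module delegates to (following the librosa formula linked in the docstring; the function name below is illustrative, the real work happens in torchaudio.functional.spectrogram_to_DB):

import math
import torch

def power_to_db_sketch(specgram, multiplier=10.0, amin=1e-10, ref_value=1.0, top_db=80.0):
    # multiplier is 10.0 for power spectrograms and 20.0 for magnitude spectrograms
    db = multiplier * torch.log10(torch.clamp(specgram, min=amin))
    db -= multiplier * math.log10(max(amin, ref_value))
    if top_db is not None:
        db = torch.clamp(db, min=db.max().item() - top_db)  # cut off top_db below the peak
    return db

specgram = torch.randn(1, 201, 100).pow(2)
print(power_to_db_sketch(specgram).shape)  # (c, f, t), same shape as the input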
class MFCC(torch.jit.ScriptModule):
-    """Create the Mel-frequency cepstrum coefficients from an audio signal
+    r"""Create the Mel-frequency cepstrum coefficients from an audio signal

     By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
     This is not the textbook implementation, but is implemented here to
     give consistency with librosa.

     This output depends on the maximum value in the input spectrogram, and so
     may return different values for an audio clip split into snippets vs. a
     a full clip.

     Args:
-        sr (int) : sample rate of audio signal
-        n_mfcc (int) : number of mfc coefficients to retain
-        dct_type (int) : type of DCT (discrete cosine transform) to use
-        norm (string, optional) : norm to use
-        log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled
+        sample_rate (int): Sample rate of audio signal. (Default: 16000)
+        n_mfcc (int): Number of mfc coefficients to retain
+        dct_type (int): type of DCT (discrete cosine transform) to use
+        norm (string, optional): norm to use
+        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled
+        melkwargs (dict, optional): arguments for MelSpectrogram
     """
-    __constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
+    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

-    def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
+    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                  melkwargs=None):
         super(MFCC, self).__init__()
         supported_dct_types = [2]
         if dct_type not in supported_dct_types:
             raise ValueError('DCT type not supported'.format(dct_type))
-        self.sr = sr
+        self.sample_rate = sample_rate
         self.n_mfcc = n_mfcc
         self.dct_type = dct_type
         self.norm = torch.jit.Attribute(norm, Optional[str])
-        self.top_db = 80.
-        self.s2db = SpectrogramToDB("power", self.top_db)
+        self.top_db = 80.0
+        self.spectrogram_to_DB = SpectrogramToDB('power', self.top_db)

         if melkwargs is not None:
-            self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs)
+            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
         else:
-            self.MelSpectrogram = MelSpectrogram(sr=self.sr)
+            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

         if self.n_mfcc > self.MelSpectrogram.n_mels:
             raise ValueError('Cannot select more MFCC coefficients than # mel bins')
...
@@ -323,29 +214,28 @@ class MFCC(torch.jit.ScriptModule):
         self.log_mels = log_mels

     @torch.jit.script_method
-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): Tensor of audio of size (channels [c], samples [n])
+            waveform (torch.Tensor): Tensor of audio of size (c, n)

         Returns:
-            spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
-                is unchanged, hops is the number of hops, and n_mels is the
-                number of mel bins.
+            torch.Tensor: specgram_mel_db of size (c, `n_mfcc`, t)
         """
-        mel_spect = self.MelSpectrogram(sig)
+        mel_specgram = self.MelSpectrogram(waveform)
         if self.log_mels:
             log_offset = 1e-6
-            mel_spect = torch.log(mel_spect + log_offset)
+            mel_specgram = torch.log(mel_specgram + log_offset)
         else:
-            mel_spect = self.s2db(mel_spect)
-        mfcc = torch.matmul(mel_spect, self.dct_mat)
+            mel_specgram = self.spectrogram_to_DB(mel_specgram)
+        # (c, `n_mels`, t).tranpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).tranpose(...)
+        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
         return mfcc
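A brief usage sketch of the renamed MFCC constructor (illustrative only; the melkwargs keys are assumed to be valid MelSpectrogram arguments, since they are forwarded as shown above):

import torch
import torchaudio.transforms as transforms

waveform = torch.randn(1, 16000)  # (c, n)

mfcc_transform = transforms.MFCC(sample_rate=16000, n_mfcc=40,
                                 melkwargs={'n_fft': 400, 'hop_length': 160, 'n_mels': 64})
mfcc = mfcc_transform(waveform)
print(mfcc.shape)  # (c, n_mfcc, t) per the updated docstring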
class MelSpectrogram(torch.jit.ScriptModule):
-    """Create MEL Spectrograms from a raw audio signal using the stft
-    function in PyTorch.
+    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
+    and MelScale.

     Sources:
         * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
...
@@ -353,87 +243,58 @@ class MelSpectrogram(torch.jit.ScriptModule):
         * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

     Args:
-        sr (int): sample rate of audio signal
-        ws (int): window size
-        hop (int, optional): length of hop between STFT windows. default: `ws` // 2
-        n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1
-        f_max (float, optional): maximum frequency. default: `sr` // 2
-        f_min (float): minimum frequency. default: 0
-        pad (int): two sided padding of signal
-        n_mels (int): number of MEL bins
-        window (torch windowing function): default: `torch.hann_window`
-        wkwargs (dict, optional): arguments for window function
+        sample_rate (int): Sample rate of audio signal. (Default: 16000)
+        win_length (int): Window size. (Default: `n_fft`)
+        hop_length (int, optional): Length of hop between STFT windows. (Default: `win_length // 2`)
+        n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
+        f_min (float): Minimum frequency. (Default: 0.)
+        f_max (float, optional): Maximum frequency. (Default: `None`)
+        pad (int): Two sided padding of signal. (Default: 0)
+        n_mels (int): Number of mel filterbanks. (Default: 128)
+        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+            that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
+        wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)

     Example:
-        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
-        >>> spec_mel = transforms.MelSpectrogram(sr)(sig)  # (c, l, m)
+        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
+        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (c, n_mels, t)
     """
-    __constants__ = ['sr', 'n_fft', 'ws', 'hop', 'pad', 'n_mels', 'f_min']
+    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

-    def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None,
-                 pad=0, n_mels=128, window=torch.hann_window, wkwargs=None):
+    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
+                 pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
         super(MelSpectrogram, self).__init__()
-        self.sr = sr
+        self.sample_rate = sample_rate
         self.n_fft = n_fft
-        self.ws = ws if ws is not None else n_fft
-        self.hop = hop if hop is not None else self.ws // 2
+        self.win_length = win_length if win_length is not None else n_fft
+        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
         self.pad = pad
         self.n_mels = n_mels  # number of mel frequency bins
         self.f_max = torch.jit.Attribute(f_max, Optional[float])
         self.f_min = f_min
-        self.spec = Spectrogram(n_fft=self.n_fft, ws=self.ws, hop=self.hop,
-                                pad=self.pad, window=window, power=2,
-                                normalize=False, wkwargs=wkwargs)
-        self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)
+        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
+                                       hop_length=self.hop_length, pad=self.pad,
+                                       window_fn=window_fn, power=2,
+                                       normalized=False, wkwargs=wkwargs)
+        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max)

     @torch.jit.script_method
-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): Tensor of audio of size (channels [c], samples [n])
+            waveform (torch.Tensor): Tensor of audio of size (c, n)

         Returns:
-            spec_mel (Tensor): channels x hops x n_mels (c, l, m), where channels
-                is unchanged, hops is the number of hops, and n_mels is the
-                number of mel bins.
+            torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
         """
-        spec = self.spec(sig)
-        spec_mel = self.fm(spec)
-        return spec_mel
-
-
-def MEL(*args, **kwargs):
-    raise DeprecationWarning("MEL has been removed from the library please use MelSpectrogram or librosa")
-
-
-class BLC2CBL(torch.jit.ScriptModule):
-    """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
-    Bands x Samples length
-    """
-
-    def __init__(self):
-        super(BLC2CBL, self).__init__()
-
-    @torch.jit.script_method
-    def forward(self, tensor):
-        """
-        Args:
-            tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
-
-        Returns:
-            tensor (Tensor): Tensor of spectrogram with shape (CxBxL)
-        """
-        return F.BLC2CBL(tensor)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
+        specgram = self.spectrogram(waveform)
+        mel_specgram = self.mel_scale(specgram)
+        return mel_specgram
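Since MelSpectrogram is now documented as a composition of Spectrogram and MelScale, here is a sketch of that equivalence (shapes only; the keyword values are illustrative):

import torch
import torchaudio.transforms as transforms

waveform = torch.randn(1, 16000)

mel_transform = transforms.MelSpectrogram(sample_rate=16000, n_fft=400,
                                          hop_length=200, n_mels=128)
by_hand = transforms.MelScale(n_mels=128, sample_rate=16000)(
    transforms.Spectrogram(n_fft=400, hop_length=200)(waveform))

print(mel_transform(waveform).shape)  # (c, n_mels, t)
print(by_hand.shape)                  # same pipeline composed manually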
class MuLawEncoding(torch.jit.ScriptModule):
-    """Encode signal based on mu-law companding. For more info see the
+    r"""Encode signal based on mu-law companding. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

     This algorithm assumes the signal has been scaled to between -1 and 1 and
...
@@ -441,33 +302,27 @@ class MuLawEncoding(torch.jit.ScriptModule):
     Args:
         quantization_channels (int): Number of channels. default: 256
     """
-    __constants__ = ['qc']
+    __constants__ = ['quantization_channels']

     def __init__(self, quantization_channels=256):
         super(MuLawEncoding, self).__init__()
-        self.qc = quantization_channels
+        self.quantization_channels = quantization_channels

     @torch.jit.script_method
     def forward(self, x):
-        """
+        r"""
         Args:
-            x (FloatTensor/LongTensor)
+            x (torch.Tensor): A signal to be encoded

         Returns:
-            x_mu (LongTensor)
+            x_mu (torch.Tensor): An encoded signal
         """
-        return F.mu_law_encoding(x, self.qc)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
+        return F.mu_law_encoding(x, self.quantization_channels)
class MuLawExpanding(torch.jit.ScriptModule):
-    """Decode mu-law encoded signal. For more info see the
+    r"""Decode mu-law encoded signal. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

     This expects an input with values between 0 and quantization_channels - 1
...
@@ -475,33 +330,27 @@ class MuLawExpanding(torch.jit.ScriptModule):
     Args:
         quantization_channels (int): Number of channels. default: 256
     """
-    __constants__ = ['qc']
+    __constants__ = ['quantization_channels']

     def __init__(self, quantization_channels=256):
         super(MuLawExpanding, self).__init__()
-        self.qc = quantization_channels
+        self.quantization_channels = quantization_channels

     @torch.jit.script_method
     def forward(self, x_mu):
-        """
+        r"""
         Args:
-            x_mu (Tensor)
+            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded

         Returns:
-            x (Tensor)
+            torch.Tensor: The signal decoded
         """
-        return F.mu_law_expanding(x_mu, self.qc)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
+        return F.mu_law_expanding(x_mu, self.quantization_channels)
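A hedged round-trip sketch with the renamed quantization_channels attribute (MuLawExpanding is the decoder counterpart; exact reconstruction error depends on the mu-law quantization step):

import torch
import torchaudio.transforms as transforms

waveform = torch.rand(1, 1000) * 2 - 1    # signal assumed scaled to [-1, 1]

encode = transforms.MuLawEncoding(quantization_channels=256)
decode = transforms.MuLawExpanding(quantization_channels=256)

x_mu = encode(waveform)                   # integer codes in [0, quantization_channels - 1]
restored = decode(x_mu)                   # back to roughly the original scale
print((waveform - restored).abs().max())  # small quantization error expected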
class Resample(torch.nn.Module):
-    """Resamples a signal from one frequency to another. A resampling method can
+    r"""Resamples a signal from one frequency to another. A resampling method can
     be given.

     Args:
...
@@ -516,15 +365,15 @@ class Resample(torch.nn.Module):
         self.new_freq = new_freq
         self.resampling_method = resampling_method

-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): the input signal of size (c, n)
+            waveform (torch.Tensor): The input signal of size (c, n)

         Returns:
-            Tensor: output signal of size (c, m)
+            torch.Tensor: Output signal of size (c, m)
         """
         if self.resampling_method == 'sinc_interpolation':
-            return kaldi.resample_waveform(sig, self.orig_freq, self.new_freq)
+            return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

         raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
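A usage sketch for the updated Resample forward (the keyword names below are inferred from the attributes shown in this diff and may differ in the constructor; 'sinc_interpolation' is the only method handled here):

import torch
import torchaudio.transforms as transforms

waveform = torch.randn(1, 16000)   # one second at 16 kHz, (c, n)

resample = transforms.Resample(orig_freq=16000, new_freq=8000,
                               resampling_method='sinc_interpolation')
resampled = resample(waveform)
print(resampled.shape)             # (c, m), roughly half as many samples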