Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
95803cf9
Commit
95803cf9
authored
May 16, 2019
by
Jason Lian
Browse files
add docs
parent
23ecb772
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
167 additions
and
5 deletions
+167
-5
torchaudio/functional.py
torchaudio/functional.py
+166
-4
torchaudio/transforms.py
torchaudio/transforms.py
+1
-1
No files found.
torchaudio/functional.py
View file @
95803cf9
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
def
scale
(
tensor
,
factor
):
def
scale
(
tensor
,
factor
):
# type: (Tensor, int) -> Tensor
# type: (Tensor, int) -> Tensor
"""Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
to a floating point number between -1.0 and 1.0. Note the 16-bit number is
called the "bit depth" or "precision", not to be confused with "bit rate".
Inputs:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
factor (int): Maximum value of input tensor. default: 16-bit depth
Outputs:
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
"""
if
not
tensor
.
dtype
.
is_floating_point
:
if
not
tensor
.
dtype
.
is_floating_point
:
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
=
tensor
.
to
(
torch
.
float32
)
...
@@ -11,9 +23,25 @@ def scale(tensor, factor):
...
@@ -11,9 +23,25 @@ def scale(tensor, factor):
def
pad_trim
(
tensor
,
ch_dim
,
max_len
,
len_dim
,
fill_value
):
def
pad_trim
(
tensor
,
ch_dim
,
max_len
,
len_dim
,
fill_value
):
# type: (Tensor, int, int, int, float) -> Tensor
# type: (Tensor, int, int, int, float) -> Tensor
"""Pad/Trim a 2d-Tensor (Signal or Labels)
Inputs:
tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
ch_dim (int): Dimension of channel (not size)
max_len (int): Length to which the tensor will be padded
len_dim (int): Dimension of length (not size)
fill_value (float): Value to fill in
Outputs:
Tensor: Padded/trimmed tensor
"""
assert
tensor
.
size
(
ch_dim
)
<
128
,
\
assert
tensor
.
size
(
ch_dim
)
<
128
,
\
"Too many channels ({}) detected, see channels_first param."
.
format
(
tensor
.
size
(
ch_dim
))
"Too many channels ({}) detected, see channels_first param."
.
format
(
tensor
.
size
(
ch_dim
))
if
max_len
>
tensor
.
size
(
len_dim
):
if
max_len
>
tensor
.
size
(
len_dim
):
# tuple of (padding_left, padding_right, padding_top, padding_bottom)
# so pad similar to append (aka only right/bottom) and do not pad
# the length dimension. assumes equal sizes of padding.
padding
=
[
max_len
-
tensor
.
size
(
len_dim
)
padding
=
[
max_len
-
tensor
.
size
(
len_dim
)
if
(
i
%
2
==
1
)
and
(
i
//
2
!=
len_dim
)
if
(
i
%
2
==
1
)
and
(
i
//
2
!=
len_dim
)
else
0
else
0
...
@@ -27,6 +55,15 @@ def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
...
@@ -27,6 +55,15 @@ def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
def
downmix_mono
(
tensor
,
ch_dim
):
def
downmix_mono
(
tensor
,
ch_dim
):
# type: (Tensor, int) -> Tensor
# type: (Tensor, int) -> Tensor
"""Downmix any stereo signals to mono.
Inputs:
tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
ch_dim (int): Dimension of channel (not size)
Outputs:
Tensor: Mono signal
"""
if
not
tensor
.
dtype
.
is_floating_point
:
if
not
tensor
.
dtype
.
is_floating_point
:
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
=
tensor
.
to
(
torch
.
float32
)
...
@@ -36,11 +73,39 @@ def downmix_mono(tensor, ch_dim):
...
@@ -36,11 +73,39 @@ def downmix_mono(tensor, ch_dim):
def
LC2CL
(
tensor
):
def
LC2CL
(
tensor
):
# type: (Tensor) -> Tensor
# type: (Tensor) -> Tensor
"""Permute a 2d tensor from samples (n x c) to (c x n)
Inputs:
tensor (Tensor): Tensor of audio signal with shape (LxC)
Outputs:
Tensor: Tensor of audio signal with shape (CxL)
"""
return
tensor
.
transpose
(
0
,
1
).
contiguous
()
return
tensor
.
transpose
(
0
,
1
).
contiguous
()
def
spectrogram
(
sig
,
pad
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
):
def
spectrogram
(
sig
,
pad
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
"""Create a spectrogram from a raw audio signal
Inputs:
sig (Tensor): Tensor of audio of size (c, n)
pad (int): two sided padding of signal
window (Tensor): window_tensor
n_fft (int): size of fft
hop (int): length of hop between STFT windows
ws (int): window size. default: n_fft
power (int > 0 ) : Exponent for the magnitude spectrogram,
e.g., 1 for energy, 2 for power, etc.
normalize (bool) : whether to normalize by magnitude after stft
Outputs:
Tensor: channels x hops x n_fft (c, l, f), where channels
is unchanged, hops is the number of hops, and n_fft is the
number of fourier bins, which should be the window size divided
by 2 plus 1.
"""
assert
sig
.
dim
()
==
2
assert
sig
.
dim
()
==
2
if
pad
>
0
:
if
pad
>
0
:
...
@@ -63,8 +128,14 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
...
@@ -63,8 +128,14 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
# type: (int, float, float, int) -> Tensor
# type: (int, float, float, int) -> Tensor
""" Create a frequency bin conversion matrix.
""" Create a frequency bin conversion matrix.
Arg
s:
Input
s:
n_stft (int): number of filter banks from spectrogram
n_stft (int): number of filter banks from spectrogram
f_min (float): minimum frequency
f_max (float): maximum frequency
n_mels (int): number of mel bins
Outputs:
Tensor: triangular filter banks (fb matrix)
"""
"""
def
_hertz_to_mel
(
f
):
def
_hertz_to_mel
(
f
):
# type: (float) -> Tensor
# type: (float) -> Tensor
...
@@ -94,6 +165,19 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
...
@@ -94,6 +165,19 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
def
mel_scale
(
spec_f
,
f_min
,
f_max
,
n_mels
,
fb
=
None
):
def
mel_scale
(
spec_f
,
f_min
,
f_max
,
n_mels
,
fb
=
None
):
# type: (Tensor, float, float, int, Optional[Tensor]) -> Tuple[Tensor, Tensor]
# type: (Tensor, float, float, int, Optional[Tensor]) -> Tuple[Tensor, Tensor]
""" This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks.
Inputs:
spec_f (Tensor): normal STFT
f_min (float): minimum frequency
f_max (float): maximum frequency
n_mels (int): number of mel bins
fb (Optional[Tensor]): triangular filter banks (fb matrix)
Outputs:
Tuple[Tensor, Tensor]: triangular filter banks (fb matrix) and mel frequency STFT
"""
if
fb
is
None
:
if
fb
is
None
:
fb
=
create_fb_matrix
(
spec_f
.
size
(
2
),
f_min
,
f_max
,
n_mels
).
to
(
spec_f
.
device
)
fb
=
create_fb_matrix
(
spec_f
.
size
(
2
),
f_min
,
f_max
,
n_mels
).
to
(
spec_f
.
device
)
else
:
else
:
...
@@ -103,8 +187,25 @@ def mel_scale(spec_f, f_min, f_max, n_mels, fb=None):
...
@@ -103,8 +187,25 @@ def mel_scale(spec_f, f_min, f_max, n_mels, fb=None):
return
fb
,
spec_m
return
fb
,
spec_m
def
spectrogram_to_DB
(
spec
,
multiplier
,
amin
,
db_multiplier
,
top_db
):
def
spectrogram_to_DB
(
spec
,
multiplier
,
amin
,
db_multiplier
,
top_db
=
None
):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. a
a full clip.
Inputs:
spec (Tensor): normal STFT
multiplier (float): use 10. for power and 20. for amplitude
amin (float): number to clamp spec
db_multiplier (float): log10(max(reference value and amin))
top_db (Optional[float]): minimum negative cut-off in decibels. A reasonable number
is 80.
Outputs:
Tensor: spectrogram in DB
"""
spec_db
=
multiplier
*
torch
.
log10
(
torch
.
clamp
(
spec
,
min
=
amin
))
spec_db
=
multiplier
*
torch
.
log10
(
torch
.
clamp
(
spec
,
min
=
amin
))
spec_db
-=
multiplier
*
db_multiplier
spec_db
-=
multiplier
*
db_multiplier
...
@@ -118,8 +219,14 @@ def create_dct(n_mfcc, n_mels, norm):
...
@@ -118,8 +219,14 @@ def create_dct(n_mfcc, n_mels, norm):
"""
"""
Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
normalized depending on norm
normalized depending on norm
Returns:
The transformation matrix, to be right-multiplied to row-wise data.
Inputs:
n_mfcc (int) : number of mfc coefficients to retain
n_mels (int): number of MEL bins
norm (string) : norm to use
Outputs:
Tensor: The transformation matrix, to be right-multiplied to row-wise data.
"""
"""
outdim
=
n_mfcc
outdim
=
n_mfcc
dim
=
n_mels
dim
=
n_mels
...
@@ -137,6 +244,26 @@ def create_dct(n_mfcc, n_mels, norm):
...
@@ -137,6 +244,26 @@ def create_dct(n_mfcc, n_mels, norm):
def
MFCC
(
sig
,
mel_spect
,
log_mels
,
s2db
,
dct_mat
):
def
MFCC
(
sig
,
mel_spect
,
log_mels
,
s2db
,
dct_mat
):
# type: (Tensor, MelSpectrogram, bool, SpectrogramToDB, Tensor) -> Tensor
# type: (Tensor, MelSpectrogram, bool, SpectrogramToDB, Tensor) -> Tensor
"""Create the Mel-frequency cepstrum coefficients from an audio signal
By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
This is not the textbook implementation, but is implemented here to
give consistency with librosa.
This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. a
a full clip.
Inputs:
sig (Tensor): Tensor of audio of size (channels [c], samples [n])
mel_spect (MelSpectrogram): melspectrogram of sig
log_mels (bool): whether to use log-mel spectrograms instead of db-scaled
s2db (SpectrogramToDB): a SpectrogramToDB instance
dct_mat (Tensor): The transformation matrix (dct matrix), to be
right-multiplied to row-wise data
Outputs:
Tensor: Mel-frequency cepstrum coefficients
"""
if
log_mels
:
if
log_mels
:
log_offset
=
1e-6
log_offset
=
1e-6
mel_spect
=
torch
.
log
(
mel_spect
+
log_offset
)
mel_spect
=
torch
.
log
(
mel_spect
+
log_offset
)
...
@@ -148,11 +275,33 @@ def MFCC(sig, mel_spect, log_mels, s2db, dct_mat):
...
@@ -148,11 +275,33 @@ def MFCC(sig, mel_spect, log_mels, s2db, dct_mat):
def
BLC2CBL
(
tensor
):
def
BLC2CBL
(
tensor
):
# type: (Tensor) -> Tensor
# type: (Tensor) -> Tensor
"""Permute a 3d tensor from Bands x Sample length x Channels to Channels x
Bands x Samples length
Inputs:
tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
Outputs:
Tensor: Tensor of spectrogram with shape (CxBxL)
"""
return
tensor
.
permute
(
2
,
0
,
1
).
contiguous
()
return
tensor
.
permute
(
2
,
0
,
1
).
contiguous
()
def
mu_law_encoding
(
x
,
qc
):
def
mu_law_encoding
(
x
,
qc
):
# type: (Tensor/ndarray, int) -> Tensor/ndarray
# type: (Tensor/ndarray, int) -> Tensor/ndarray
"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This algorithm assumes the signal has been scaled to between -1 and 1 and
returns a signal encoded with values from 0 to quantization_channels - 1
Inputs:
x (Tensor): Input tensor
qc (int): Number of channels (i.e. quantization channels)
Outputs:
Tensor: Input after mu-law companding
"""
mu
=
qc
-
1.
mu
=
qc
-
1.
if
isinstance
(
x
,
np
.
ndarray
):
if
isinstance
(
x
,
np
.
ndarray
):
x_mu
=
np
.
sign
(
x
)
*
np
.
log1p
(
mu
*
np
.
abs
(
x
))
/
np
.
log1p
(
mu
)
x_mu
=
np
.
sign
(
x
)
*
np
.
log1p
(
mu
*
np
.
abs
(
x
))
/
np
.
log1p
(
mu
)
...
@@ -169,6 +318,19 @@ def mu_law_encoding(x, qc):
...
@@ -169,6 +318,19 @@ def mu_law_encoding(x, qc):
def
mu_law_expanding
(
x
,
qc
):
def
mu_law_expanding
(
x
,
qc
):
# type: (Tensor/ndarray, int) -> Tensor/ndarray
# type: (Tensor/ndarray, int) -> Tensor/ndarray
"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This expects an input with values between 0 and quantization_channels - 1
and returns a signal scaled between -1 and 1.
Inputs:
x (Tensor): Input tensor
qc (int): Number of channels (i.e. quantization channels)
Outputs:
Tensor: Input after decoding
"""
mu
=
qc
-
1.
mu
=
qc
-
1.
if
isinstance
(
x_mu
,
np
.
ndarray
):
if
isinstance
(
x_mu
,
np
.
ndarray
):
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
...
...
torchaudio/transforms.py
View file @
95803cf9
...
@@ -65,7 +65,7 @@ class Scale(object):
...
@@ -65,7 +65,7 @@ class Scale(object):
class
PadTrim
(
object
):
class
PadTrim
(
object
):
"""Pad/Trim a
1
d-Tensor (Signal or Labels)
"""Pad/Trim a
2
d-Tensor (Signal or Labels)
Args:
Args:
tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment