Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
3f122ae1
Unverified
Commit
3f122ae1
authored
Jul 25, 2019
by
jamarshon
Committed by
GitHub
Jul 25, 2019
Browse files
Reorder transforms.py and functional.py + add __all__
parent
2f62e573
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
131 additions
and
119 deletions
+131
-119
torchaudio/functional.py
torchaudio/functional.py
+32
-32
torchaudio/transforms.py
torchaudio/transforms.py
+99
-87
No files found.
torchaudio/functional.py
View file @
3f122ae1
...
...
@@ -5,8 +5,8 @@ import torch
__all__
=
[
'istft'
,
'spectrogram'
,
'create_fb_matrix'
,
'spectrogram_to_DB'
,
'create_fb_matrix'
,
'create_dct'
,
'mu_law_encoding'
,
'mu_law_decoding'
,
...
...
@@ -206,6 +206,37 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor
return
spec_f
@
torch
.
jit
.
script
def
spectrogram_to_DB
(
specgram
,
multiplier
,
amin
,
db_multiplier
,
top_db
=
None
):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
r
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. a
a full clip.
Args:
specgram (torch.Tensor): Normal STFT of size (c, f, t)
multiplier (float): Use 10. for power and 20. for amplitude
amin (float): Number to clamp specgram
db_multiplier (float): Log10(max(reference value and amin))
top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
is 80.
Returns:
torch.Tensor: Spectrogram in DB of size (c, f, t)
"""
specgram_db
=
multiplier
*
torch
.
log10
(
torch
.
clamp
(
specgram
,
min
=
amin
))
specgram_db
-=
multiplier
*
db_multiplier
if
top_db
is
not
None
:
new_spec_db_max
=
torch
.
tensor
(
float
(
specgram_db
.
max
())
-
top_db
,
dtype
=
specgram_db
.
dtype
,
device
=
specgram_db
.
device
)
specgram_db
=
torch
.
max
(
specgram_db
,
new_spec_db_max
)
return
specgram_db
@
torch
.
jit
.
script
def
create_fb_matrix
(
n_freqs
,
f_min
,
f_max
,
n_mels
):
# type: (int, float, float, int) -> Tensor
...
...
@@ -244,37 +275,6 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels):
return
fb
@
torch
.
jit
.
script
def
spectrogram_to_DB
(
specgram
,
multiplier
,
amin
,
db_multiplier
,
top_db
=
None
):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
r
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. a
a full clip.
Args:
specgram (torch.Tensor): Normal STFT of size (c, f, t)
multiplier (float): Use 10. for power and 20. for amplitude
amin (float): Number to clamp specgram
db_multiplier (float): Log10(max(reference value and amin))
top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
is 80.
Returns:
torch.Tensor: Spectrogram in DB of size (c, f, t)
"""
specgram_db
=
multiplier
*
torch
.
log10
(
torch
.
clamp
(
specgram
,
min
=
amin
))
specgram_db
-=
multiplier
*
db_multiplier
if
top_db
is
not
None
:
new_spec_db_max
=
torch
.
tensor
(
float
(
specgram_db
.
max
())
-
top_db
,
dtype
=
specgram_db
.
dtype
,
device
=
specgram_db
.
device
)
specgram_db
=
torch
.
max
(
specgram_db
,
new_spec_db_max
)
return
specgram_db
@
torch
.
jit
.
script
def
create_dct
(
n_mfcc
,
n_mels
,
norm
):
# type: (int, int, Optional[str]) -> Tensor
...
...
torchaudio/transforms.py
View file @
3f122ae1
...
...
@@ -7,6 +7,18 @@ from . import functional as F
from
.compliance
import
kaldi
__all__
=
[
'Spectrogram'
,
'SpectrogramToDB'
,
'MelScale'
,
'MelSpectrogram'
,
'MFCC'
,
'MuLawEncoding'
,
'MuLawDecoding'
,
'Resample'
,
]
class
Spectrogram
(
torch
.
jit
.
ScriptModule
):
r
"""Create a spectrogram from a audio signal
...
...
@@ -55,6 +67,46 @@ class Spectrogram(torch.jit.ScriptModule):
self
.
win_length
,
self
.
power
,
self
.
normalized
)
class
SpectrogramToDB
(
torch
.
jit
.
ScriptModule
):
r
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. a
a full clip.
Args:
stype (str): scale of input spectrogram ('power' or 'magnitude'). The
power being the elementwise square of the magnitude. (Default: 'power')
top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
is 80.
"""
__constants__
=
[
'multiplier'
,
'amin'
,
'ref_value'
,
'db_multiplier'
]
def
__init__
(
self
,
stype
=
'power'
,
top_db
=
None
):
super
(
SpectrogramToDB
,
self
).
__init__
()
self
.
stype
=
torch
.
jit
.
Attribute
(
stype
,
str
)
if
top_db
is
not
None
and
top_db
<
0
:
raise
ValueError
(
'top_db must be positive value'
)
self
.
top_db
=
torch
.
jit
.
Attribute
(
top_db
,
Optional
[
float
])
self
.
multiplier
=
10.0
if
stype
==
'power'
else
20.0
self
.
amin
=
1e-10
self
.
ref_value
=
1.0
self
.
db_multiplier
=
math
.
log10
(
max
(
self
.
amin
,
self
.
ref_value
))
@
torch
.
jit
.
script_method
def
forward
(
self
,
specgram
):
r
"""Numerically stable implementation from Librosa
https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
Args:
specgram (torch.Tensor): STFT of size (c, f, t)
Returns:
torch.Tensor: STFT after changing scale of size (c, f, t)
"""
return
F
.
spectrogram_to_DB
(
specgram
,
self
.
multiplier
,
self
.
amin
,
self
.
db_multiplier
,
self
.
top_db
)
class
MelScale
(
torch
.
jit
.
ScriptModule
):
r
"""This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks.
...
...
@@ -102,44 +154,64 @@ class MelScale(torch.jit.ScriptModule):
return
mel_specgram
class
SpectrogramToDB
(
torch
.
jit
.
ScriptModule
):
r
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
class
MelSpectrogram
(
torch
.
jit
.
ScriptModule
):
r
"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
and MelScale.
This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. a
a full clip.
Sources:
* https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
* https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
* http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
Args:
stype (str): scale of input spectrogram ('power' or 'magnitude'). The
power being the elementwise square of the magnitude. (Default: 'power')
top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
is 80.
sample_rate (int): Sample rate of audio signal. (Default: 16000)
win_length (int): Window size. (Default: `n_fft`)
hop_length (int, optional): Length of hop between STFT windows. (
Default: `win_length // 2`)
n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
f_min (float): Minimum frequency. (Default: 0.)
f_max (float, optional): Maximum frequency. (Default: `None`)
pad (int): Two sided padding of signal. (Default: 0)
n_mels (int): Number of mel filterbanks. (Default: 128)
window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
Example:
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
>>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t)
"""
__constants__
=
[
'
multiplier'
,
'amin'
,
'ref_value'
,
'db_multiplier
'
]
__constants__
=
[
'
sample_rate'
,
'n_fft'
,
'win_length'
,
'hop_length'
,
'pad'
,
'n_mels'
,
'f_min
'
]
def
__init__
(
self
,
stype
=
'power'
,
top_db
=
None
):
super
(
SpectrogramToDB
,
self
).
__init__
()
self
.
stype
=
torch
.
jit
.
Attribute
(
stype
,
str
)
if
top_db
is
not
None
and
top_db
<
0
:
raise
ValueError
(
'top_db must be positive value'
)
self
.
top_db
=
torch
.
jit
.
Attribute
(
top_db
,
Optional
[
float
])
self
.
multiplier
=
10.0
if
stype
==
'power'
else
20.0
self
.
amin
=
1e-10
self
.
ref_value
=
1.0
self
.
db_multiplier
=
math
.
log10
(
max
(
self
.
amin
,
self
.
ref_value
))
def
__init__
(
self
,
sample_rate
=
16000
,
n_fft
=
400
,
win_length
=
None
,
hop_length
=
None
,
f_min
=
0.
,
f_max
=
None
,
pad
=
0
,
n_mels
=
128
,
window_fn
=
torch
.
hann_window
,
wkwargs
=
None
):
super
(
MelSpectrogram
,
self
).
__init__
()
self
.
sample_rate
=
sample_rate
self
.
n_fft
=
n_fft
self
.
win_length
=
win_length
if
win_length
is
not
None
else
n_fft
self
.
hop_length
=
hop_length
if
hop_length
is
not
None
else
self
.
win_length
//
2
self
.
pad
=
pad
self
.
n_mels
=
n_mels
# number of mel frequency bins
self
.
f_max
=
torch
.
jit
.
Attribute
(
f_max
,
Optional
[
float
])
self
.
f_min
=
f_min
self
.
spectrogram
=
Spectrogram
(
n_fft
=
self
.
n_fft
,
win_length
=
self
.
win_length
,
hop_length
=
self
.
hop_length
,
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
2
,
normalized
=
False
,
wkwargs
=
wkwargs
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
)
@
torch
.
jit
.
script_method
def
forward
(
self
,
specgram
):
r
"""Numerically stable implementation from Librosa
https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
def
forward
(
self
,
waveform
):
r
"""
Args:
specgra
m (torch.Tensor):
STFT
of size (c,
f, t
)
wavefor
m (torch.Tensor):
Tensor of audio
of size (c,
n
)
Returns:
torch.Tensor:
STFT after changing scale
of size (c,
f
, t)
torch.Tensor:
mel frequency spectrogram
of size (c,
`n_mels`
, t)
"""
return
F
.
spectrogram_to_DB
(
specgram
,
self
.
multiplier
,
self
.
amin
,
self
.
db_multiplier
,
self
.
top_db
)
specgram
=
self
.
spectrogram
(
waveform
)
mel_specgram
=
self
.
mel_scale
(
specgram
)
return
mel_specgram
class
MFCC
(
torch
.
jit
.
ScriptModule
):
...
...
@@ -207,66 +279,6 @@ class MFCC(torch.jit.ScriptModule):
return
mfcc
class
MelSpectrogram
(
torch
.
jit
.
ScriptModule
):
r
"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
and MelScale.
Sources:
* https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
* https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
* http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
Args:
sample_rate (int): Sample rate of audio signal. (Default: 16000)
win_length (int): Window size. (Default: `n_fft`)
hop_length (int, optional): Length of hop between STFT windows. (
Default: `win_length // 2`)
n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
f_min (float): Minimum frequency. (Default: 0.)
f_max (float, optional): Maximum frequency. (Default: `None`)
pad (int): Two sided padding of signal. (Default: 0)
n_mels (int): Number of mel filterbanks. (Default: 128)
window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
Example:
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
>>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t)
"""
__constants__
=
[
'sample_rate'
,
'n_fft'
,
'win_length'
,
'hop_length'
,
'pad'
,
'n_mels'
,
'f_min'
]
def
__init__
(
self
,
sample_rate
=
16000
,
n_fft
=
400
,
win_length
=
None
,
hop_length
=
None
,
f_min
=
0.
,
f_max
=
None
,
pad
=
0
,
n_mels
=
128
,
window_fn
=
torch
.
hann_window
,
wkwargs
=
None
):
super
(
MelSpectrogram
,
self
).
__init__
()
self
.
sample_rate
=
sample_rate
self
.
n_fft
=
n_fft
self
.
win_length
=
win_length
if
win_length
is
not
None
else
n_fft
self
.
hop_length
=
hop_length
if
hop_length
is
not
None
else
self
.
win_length
//
2
self
.
pad
=
pad
self
.
n_mels
=
n_mels
# number of mel frequency bins
self
.
f_max
=
torch
.
jit
.
Attribute
(
f_max
,
Optional
[
float
])
self
.
f_min
=
f_min
self
.
spectrogram
=
Spectrogram
(
n_fft
=
self
.
n_fft
,
win_length
=
self
.
win_length
,
hop_length
=
self
.
hop_length
,
pad
=
self
.
pad
,
window_fn
=
window_fn
,
power
=
2
,
normalized
=
False
,
wkwargs
=
wkwargs
)
self
.
mel_scale
=
MelScale
(
self
.
n_mels
,
self
.
sample_rate
,
self
.
f_min
,
self
.
f_max
)
@
torch
.
jit
.
script_method
def
forward
(
self
,
waveform
):
r
"""
Args:
waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns:
torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
"""
specgram
=
self
.
spectrogram
(
waveform
)
mel_specgram
=
self
.
mel_scale
(
specgram
)
return
mel_specgram
class
MuLawEncoding
(
torch
.
jit
.
ScriptModule
):
r
"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment