Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
7ee1c46b
Unverified
Commit
7ee1c46b
authored
Feb 09, 2021
by
moto
Committed by
GitHub
Feb 09, 2021
Browse files
Add Kaldi Pitch feature (#1243)
parent
9e58e75c
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
202 additions
and
0 deletions
+202
-0
torchaudio/csrc/CMakeLists.txt
torchaudio/csrc/CMakeLists.txt
+4
-0
torchaudio/csrc/kaldi.cpp
torchaudio/csrc/kaldi.cpp
+93
-0
torchaudio/functional/__init__.py
torchaudio/functional/__init__.py
+2
-0
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+103
-0
No files found.
torchaudio/csrc/CMakeLists.txt
View file @
7ee1c46b
...
@@ -15,6 +15,10 @@ if(BUILD_TRANSDUCER)
...
@@ -15,6 +15,10 @@ if(BUILD_TRANSDUCER)
list
(
APPEND LIBTORCHAUDIO_SOURCES transducer.cpp
)
list
(
APPEND LIBTORCHAUDIO_SOURCES transducer.cpp
)
endif
()
endif
()
# When BUILD_KALDI is enabled, add the Kaldi binding source (kaldi.cpp) to
# LIBTORCHAUDIO_SOURCES, mirroring the optional-source pattern used for
# transducer.cpp above.
if(BUILD_KALDI)
  list(APPEND LIBTORCHAUDIO_SOURCES kaldi.cpp)
endif()
################################################################################
################################################################################
# libtorchaudio.so
# libtorchaudio.so
################################################################################
################################################################################
...
...
torchaudio/csrc/kaldi.cpp
0 → 100644
View file @
7ee1c46b
#include <torch/script.h>
#include "feat/pitch-functions.h"
namespace
torchaudio
{
namespace
kaldi
{
namespace
{
/// Scale waveform samples into the int16 value range, expressed as float.
///
/// Positive samples are multiplied by 32767 and negative samples by 32768;
/// zeros stay zero. The input tensor is left untouched and a new tensor is
/// returned. (Presumably the input is normalized to [-1.0, 1.0] -- confirm
/// against callers.)
///
/// NOTE: the earlier implementation used the out-of-place `index_put` and
/// dropped its return value, so the input came back unscaled. Selecting the
/// scale factor with `torch::where` avoids both the aliasing of the input and
/// the discarded result.
torch::Tensor denormalize(const torch::Tensor& t) {
  // For t == 0 either branch yields 0, so only the sign of t matters here.
  return torch::where(t > 0, t * 32767, t * 32768);
}
/// Run Kaldi pitch extraction on one waveform tensor.
///
/// `wave` is wrapped in a `::kaldi::VectorBase`, handed to
/// `::kaldi::ComputeKaldiPitch` together with `opts`, and the tensor held by
/// the resulting `::kaldi::Matrix` is returned.
torch::Tensor compute_kaldi_pitch(
    const torch::Tensor& wave,
    const ::kaldi::PitchExtractionOptions& opts) {
  using Float = ::kaldi::BaseFloat;

  ::kaldi::VectorBase<Float> samples(wave);
  ::kaldi::Matrix<Float> features;
  ::kaldi::ComputeKaldiPitch(opts, samples, &features);
  return features.tensor_;
}
}
// namespace
/// Compute Kaldi pitch features for a batch of waveforms.
///
/// @param wave 2-D, CPU, float32 tensor. Each entry along dimension 0 is
///     processed independently.
/// @param sample_frequency Copied to `opts.samp_freq`.
/// @param frame_length Copied to `opts.frame_length_ms`.
/// @param frame_shift Copied to `opts.frame_shift_ms`.
/// @param min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff,
///     resample_frequency, delta_pitch, nccf_ballast Copied (as
///     `::kaldi::BaseFloat`) to the same-named option fields
///     (`resample_frequency` goes to `opts.resample_freq`).
/// @param lowpass_filter_width, upsample_filter_width, max_frames_latency,
///     frames_per_chunk, recompute_frame Copied (as `::kaldi::int32`) to the
///     same-named option fields.
/// @param simulate_first_pass_online, snip_edges Copied to the same-named
///     boolean option fields.
/// @return The per-waveform results stacked along a new leading dimension.
/// @throws c10::Error (via TORCH_CHECK) if `wave` is not 2-D, not on CPU, or
///     not float32.
torch::Tensor ComputeKaldiPitch(
    const torch::Tensor& wave,
    double sample_frequency,
    double frame_length,
    double frame_shift,
    double min_f0,
    double max_f0,
    double soft_min_f0,
    double penalty_factor,
    double lowpass_cutoff,
    double resample_frequency,
    double delta_pitch,
    double nccf_ballast,
    int64_t lowpass_filter_width,
    int64_t upsample_filter_width,
    int64_t max_frames_latency,
    int64_t frames_per_chunk,
    bool simulate_first_pass_online,
    int64_t recompute_frame,
    bool snip_edges) {
  TORCH_CHECK(wave.ndimension() == 2, "Input tensor must be 2 dimentional.");
  TORCH_CHECK(wave.device().is_cpu(), "Input tensor must be on CPU.");
  TORCH_CHECK(
      wave.dtype() == torch::kFloat32, "Input tensor must be float32 type.");

  ::kaldi::PitchExtractionOptions opts;
  opts.samp_freq = static_cast<::kaldi::BaseFloat>(sample_frequency);
  opts.frame_shift_ms = static_cast<::kaldi::BaseFloat>(frame_shift);
  opts.frame_length_ms = static_cast<::kaldi::BaseFloat>(frame_length);
  opts.min_f0 = static_cast<::kaldi::BaseFloat>(min_f0);
  opts.max_f0 = static_cast<::kaldi::BaseFloat>(max_f0);
  opts.soft_min_f0 = static_cast<::kaldi::BaseFloat>(soft_min_f0);
  opts.penalty_factor = static_cast<::kaldi::BaseFloat>(penalty_factor);
  opts.lowpass_cutoff = static_cast<::kaldi::BaseFloat>(lowpass_cutoff);
  opts.resample_freq = static_cast<::kaldi::BaseFloat>(resample_frequency);
  opts.delta_pitch = static_cast<::kaldi::BaseFloat>(delta_pitch);
  // Previously omitted: the caller-supplied nccf_ballast was silently ignored
  // and the option kept its default value.
  opts.nccf_ballast = static_cast<::kaldi::BaseFloat>(nccf_ballast);
  opts.lowpass_filter_width =
      static_cast<::kaldi::int32>(lowpass_filter_width);
  opts.upsample_filter_width =
      static_cast<::kaldi::int32>(upsample_filter_width);
  opts.max_frames_latency = static_cast<::kaldi::int32>(max_frames_latency);
  opts.frames_per_chunk = static_cast<::kaldi::int32>(frames_per_chunk);
  opts.simulate_first_pass_online = simulate_first_pass_online;
  opts.recompute_frame = static_cast<::kaldi::int32>(recompute_frame);
  opts.snip_edges = snip_edges;

  // Kaldi's float type expects value range of int16 expressed as float
  torch::Tensor wave_ = denormalize(wave);

  auto batch_size = wave_.size(0);
  std::vector<torch::Tensor> results(batch_size);
  // Each iteration writes only its own slot of `results`, so capturing by
  // reference is safe; parallel_for returns after all chunks have run.
  at::parallel_for(0, batch_size, 1, [&](int64_t begin, int64_t end) {
    for (auto i = begin; i < end; ++i) {
      results[i] = compute_kaldi_pitch(wave_.index({i}), opts);
    }
  });
  return torch::stack(results, 0);
}
// Register ComputeKaldiPitch as the operator
// "torchaudio::kaldi_ComputeKaldiPitch" in the `torchaudio` library fragment.
// The Python wrapper reaches it as
// `torch.ops.torchaudio.kaldi_ComputeKaldiPitch`.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "torchaudio::kaldi_ComputeKaldiPitch",
      &torchaudio::kaldi::ComputeKaldiPitch);
}
}
// namespace kaldi
}
// namespace torchaudio
torchaudio/functional/__init__.py
View file @
7ee1c46b
...
@@ -3,6 +3,7 @@ from .functional import (
...
@@ -3,6 +3,7 @@ from .functional import (
angle
,
angle
,
complex_norm
,
complex_norm
,
compute_deltas
,
compute_deltas
,
compute_kaldi_pitch
,
create_dct
,
create_dct
,
create_fb_matrix
,
create_fb_matrix
,
DB_to_amplitude
,
DB_to_amplitude
,
...
@@ -47,6 +48,7 @@ __all__ = [
...
@@ -47,6 +48,7 @@ __all__ = [
'angle'
,
'angle'
,
'complex_norm'
,
'complex_norm'
,
'compute_deltas'
,
'compute_deltas'
,
'compute_kaldi_pitch'
,
'create_dct'
,
'create_dct'
,
'create_fb_matrix'
,
'create_fb_matrix'
,
'DB_to_amplitude'
,
'DB_to_amplitude'
,
...
...
torchaudio/functional/functional.py
View file @
7ee1c46b
...
@@ -13,6 +13,7 @@ __all__ = [
...
@@ -13,6 +13,7 @@ __all__ = [
"amplitude_to_DB"
,
"amplitude_to_DB"
,
"DB_to_amplitude"
,
"DB_to_amplitude"
,
"compute_deltas"
,
"compute_deltas"
,
"compute_kaldi_pitch"
,
"create_fb_matrix"
,
"create_fb_matrix"
,
"create_dct"
,
"create_dct"
,
"compute_deltas"
,
"compute_deltas"
,
...
@@ -991,3 +992,105 @@ def spectral_centroid(
...
@@ -991,3 +992,105 @@ def spectral_centroid(
device
=
specgram
.
device
).
reshape
((
-
1
,
1
))
device
=
specgram
.
device
).
reshape
((
-
1
,
1
))
freq_dim
=
-
2
freq_dim
=
-
2
return
(
freqs
*
specgram
).
sum
(
dim
=
freq_dim
)
/
specgram
.
sum
(
dim
=
freq_dim
)
return
(
freqs
*
specgram
).
sum
(
dim
=
freq_dim
)
/
specgram
.
sum
(
dim
=
freq_dim
)
def compute_kaldi_pitch(
        waveform: torch.Tensor,
        sample_rate: float,
        frame_length: float = 25.0,
        frame_shift: float = 10.0,
        min_f0: float = 50,
        max_f0: float = 400,
        soft_min_f0: float = 10.0,
        penalty_factor: float = 0.1,
        lowpass_cutoff: float = 1000,
        resample_frequency: float = 4000,
        delta_pitch: float = 0.005,
        nccf_ballast: float = 7000,
        lowpass_filter_width: int = 1,
        upsample_filter_width: int = 5,
        max_frames_latency: int = 0,
        frames_per_chunk: int = 0,
        simulate_first_pass_online: bool = False,
        recompute_frame: int = 500,
        snip_edges: bool = True,
) -> torch.Tensor:
    """Extract pitch features with the algorithm of Ghahremani et al. (see Reference).

    This is the counterpart of Kaldi's `compute-kaldi-pitch-feats`. All leading
    dimensions of ``waveform`` are flattened into one batch dimension before the
    ``torchaudio::kaldi_ComputeKaldiPitch`` operator is called, and restored on
    the result afterwards.

    Args:
        waveform (Tensor):
            The input waveform of shape `(..., time)`.
        sample_rate (float):
            Sample rate of `waveform`.
        frame_length (float, optional):
            Frame length in milliseconds.
        frame_shift (float, optional):
            Frame shift in milliseconds.
        min_f0 (float, optional):
            Minimum F0 to search for (Hz).
        max_f0 (float, optional):
            Maximum F0 to search for (Hz).
        soft_min_f0 (float, optional):
            Minimum F0, applied in a soft way; must not exceed ``min_f0``.
        penalty_factor (float, optional):
            Cost factor for F0 change.
        lowpass_cutoff (float, optional):
            Cutoff frequency for the low-pass filter (Hz).
        resample_frequency (float, optional):
            Frequency that the signal is down-sampled to. Must be more than
            twice ``lowpass_cutoff``.
        delta_pitch (float, optional):
            Smallest relative change in pitch that the algorithm measures.
        nccf_ballast (float, optional):
            Increasing this factor reduces NCCF for quiet frames.
        lowpass_filter_width (int, optional):
            Integer that determines the filter width of the low-pass filter;
            larger values give a sharper filter.
        upsample_filter_width (int, optional):
            Integer that determines the filter width when upsampling NCCF.
        max_frames_latency (int, optional):
            Maximum number of frames of latency that pitch tracking is allowed
            to introduce into the feature processing (affects output only if
            ``frames_per_chunk > 0`` and ``simulate_first_pass_online=True``).
        frames_per_chunk (int, optional):
            The number of frames used for energy normalization.
        simulate_first_pass_online (bool, optional):
            If true, output the features an online decoder would see in the
            first pass of decoding rather than the final version of the
            features (which is the default).
            Relevant if ``frames_per_chunk > 0``.
        recompute_frame (int, optional):
            Only relevant for compatibility with online pitch extraction.
            A non-critical parameter; the frame at which some of the forward
            pointers are recomputed after revising the estimate of the signal
            energy. Relevant if ``frames_per_chunk > 0``.
        snip_edges (bool, optional):
            If false, the incomplete frames near the ending edge are not
            snipped, so that the number of frames is the file size divided by
            the frame shift. This makes different types of features give the
            same number of frames.

    Returns:
        Tensor: Pitch feature. Shape: `(..., frames, 2)` where the last
        dimension corresponds to pitch and NCCF.

    Reference:
        - A pitch extraction algorithm tuned for automatic speech recognition
          P. Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S. Khudanpur
          2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
          Florence, 2014, pp. 2494-2498, doi: 10.1109/ICASSP.2014.6854049.
    """
    # Remember the leading dims so they can be restored on the output.
    leading_dims = waveform.shape[:-1]
    # The operator expects a 2-D (batch, time) tensor.
    batched = waveform.reshape(-1, waveform.shape[-1])
    pitch = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
        batched,
        sample_rate,
        frame_length,
        frame_shift,
        min_f0,
        max_f0,
        soft_min_f0,
        penalty_factor,
        lowpass_cutoff,
        resample_frequency,
        delta_pitch,
        nccf_ballast,
        lowpass_filter_width,
        upsample_filter_width,
        max_frames_latency,
        frames_per_chunk,
        simulate_first_pass_online,
        recompute_frame,
        snip_edges,
    )
    # Replace the flattened batch dim with the original leading dims, keeping
    # the trailing two dims of the operator output.
    return pitch.reshape(leading_dims + pitch.shape[-2:])
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment