You need to sign in or sign up before continuing.
Unverified Commit 7ee1c46b authored by moto's avatar moto Committed by GitHub
Browse files

Add Kaldi Pitch feature (#1243)

parent 9e58e75c
...@@ -15,6 +15,10 @@ if(BUILD_TRANSDUCER) ...@@ -15,6 +15,10 @@ if(BUILD_TRANSDUCER)
list(APPEND LIBTORCHAUDIO_SOURCES transducer.cpp) list(APPEND LIBTORCHAUDIO_SOURCES transducer.cpp)
endif() endif()
if(BUILD_KALDI)
list(APPEND LIBTORCHAUDIO_SOURCES kaldi.cpp)
endif()
################################################################################ ################################################################################
# libtorchaudio.so # libtorchaudio.so
################################################################################ ################################################################################
......
#include <torch/script.h>
#include "feat/pitch-functions.h"
namespace torchaudio {
namespace kaldi {
namespace {
torch::Tensor denormalize(const torch::Tensor& t) {
auto ret = t;
auto pos = t > 0, neg = t < 0;
ret.index_put({pos}, t.index({pos}) * 32767);
ret.index_put({neg}, t.index({neg}) * 32768);
return ret;
}
torch::Tensor compute_kaldi_pitch(
const torch::Tensor& wave,
const ::kaldi::PitchExtractionOptions& opts) {
::kaldi::VectorBase<::kaldi::BaseFloat> input(wave);
::kaldi::Matrix<::kaldi::BaseFloat> output;
::kaldi::ComputeKaldiPitch(opts, input, &output);
return output.tensor_;
}
} // namespace
torch::Tensor ComputeKaldiPitch(
const torch::Tensor& wave,
double sample_frequency,
double frame_length,
double frame_shift,
double min_f0,
double max_f0,
double soft_min_f0,
double penalty_factor,
double lowpass_cutoff,
double resample_frequency,
double delta_pitch,
double nccf_ballast,
int64_t lowpass_filter_width,
int64_t upsample_filter_width,
int64_t max_frames_latency,
int64_t frames_per_chunk,
bool simulate_first_pass_online,
int64_t recompute_frame,
bool snip_edges) {
TORCH_CHECK(wave.ndimension() == 2, "Input tensor must be 2 dimentional.");
TORCH_CHECK(wave.device().is_cpu(), "Input tensor must be on CPU.");
TORCH_CHECK(
wave.dtype() == torch::kFloat32, "Input tensor must be float32 type.");
::kaldi::PitchExtractionOptions opts;
opts.samp_freq = static_cast<::kaldi::BaseFloat>(sample_frequency);
opts.frame_shift_ms = static_cast<::kaldi::BaseFloat>(frame_shift);
opts.frame_length_ms = static_cast<::kaldi::BaseFloat>(frame_length);
opts.min_f0 = static_cast<::kaldi::BaseFloat>(min_f0);
opts.max_f0 = static_cast<::kaldi::BaseFloat>(max_f0);
opts.soft_min_f0 = static_cast<::kaldi::BaseFloat>(soft_min_f0);
opts.penalty_factor = static_cast<::kaldi::BaseFloat>(penalty_factor);
opts.lowpass_cutoff = static_cast<::kaldi::BaseFloat>(lowpass_cutoff);
opts.resample_freq = static_cast<::kaldi::BaseFloat>(resample_frequency);
opts.delta_pitch = static_cast<::kaldi::BaseFloat>(delta_pitch);
opts.lowpass_filter_width = static_cast<::kaldi::int32>(lowpass_filter_width);
opts.upsample_filter_width =
static_cast<::kaldi::int32>(upsample_filter_width);
opts.max_frames_latency = static_cast<::kaldi::int32>(max_frames_latency);
opts.frames_per_chunk = static_cast<::kaldi::int32>(frames_per_chunk);
opts.simulate_first_pass_online = simulate_first_pass_online;
opts.recompute_frame = static_cast<::kaldi::int32>(recompute_frame);
opts.snip_edges = snip_edges;
// Kaldi's float type expects value range of int16 expressed as float
torch::Tensor wave_ = denormalize(wave);
auto batch_size = wave_.size(0);
std::vector<torch::Tensor> results(batch_size);
at::parallel_for(0, batch_size, 1, [&](int64_t begin, int64_t end) {
for (auto i = begin; i < end; ++i) {
results[i] = compute_kaldi_pitch(wave_.index({i}), opts);
}
});
return torch::stack(results, 0);
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::kaldi_ComputeKaldiPitch",
&torchaudio::kaldi::ComputeKaldiPitch);
}
} // namespace kaldi
} // namespace torchaudio
...@@ -3,6 +3,7 @@ from .functional import ( ...@@ -3,6 +3,7 @@ from .functional import (
angle, angle,
complex_norm, complex_norm,
compute_deltas, compute_deltas,
compute_kaldi_pitch,
create_dct, create_dct,
create_fb_matrix, create_fb_matrix,
DB_to_amplitude, DB_to_amplitude,
...@@ -47,6 +48,7 @@ __all__ = [ ...@@ -47,6 +48,7 @@ __all__ = [
'angle', 'angle',
'complex_norm', 'complex_norm',
'compute_deltas', 'compute_deltas',
'compute_kaldi_pitch',
'create_dct', 'create_dct',
'create_fb_matrix', 'create_fb_matrix',
'DB_to_amplitude', 'DB_to_amplitude',
......
...@@ -13,6 +13,7 @@ __all__ = [ ...@@ -13,6 +13,7 @@ __all__ = [
"amplitude_to_DB", "amplitude_to_DB",
"DB_to_amplitude", "DB_to_amplitude",
"compute_deltas", "compute_deltas",
"compute_kaldi_pitch",
"create_fb_matrix", "create_fb_matrix",
"create_dct", "create_dct",
"compute_deltas", "compute_deltas",
...@@ -991,3 +992,105 @@ def spectral_centroid( ...@@ -991,3 +992,105 @@ def spectral_centroid(
device=specgram.device).reshape((-1, 1)) device=specgram.device).reshape((-1, 1))
freq_dim = -2 freq_dim = -2
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
def compute_kaldi_pitch(
waveform: torch.Tensor,
sample_rate: float,
frame_length: float = 25.0,
frame_shift: float = 10.0,
min_f0: float = 50,
max_f0: float = 400,
soft_min_f0: float = 10.0,
penalty_factor: float = 0.1,
lowpass_cutoff: float = 1000,
resample_frequency: float = 4000,
delta_pitch: float = 0.005,
nccf_ballast: float = 7000,
lowpass_filter_width: int = 1,
upsample_filter_width: int = 5,
max_frames_latency: int = 0,
frames_per_chunk: int = 0,
simulate_first_pass_online: bool = False,
recompute_frame: int = 500,
snip_edges: bool = True,
) -> torch.Tensor:
"""Extract pitch based on method described in [1].
This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.
Args:
waveform (Tensor):
The input waveform of shape `(..., time)`.
sample_rate (float):
Sample rate of `waveform`.
frame_length (float, optional):
Frame length in milliseconds.
frame_shift (float, optional):
Frame shift in milliseconds.
min_f0 (float, optional):
Minimum F0 to search for (Hz)
max_f0 (float, optional):
Maximum F0 to search for (Hz)
soft_min_f0 (float, optional):
Minimum f0, applied in soft way, must not exceed min-f0
penalty_factor (float, optional):
Cost factor for FO change.
lowpass_cutoff (float, optional):
Cutoff frequency for LowPass filter (Hz)
resample_frequency (float, optional):
Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
delta_pitch( float, optional):
Smallest relative change in pitch that our algorithm measures.
nccf_ballast (float, optional):
Increasing this factor reduces NCCF for quiet frames
lowpass_filter_width (int, optional):
Integer that determines filter width of lowpass filter, more gives sharper filter.
upsample_filter_width (int, optional):
Integer that determines filter width when upsampling NCCF.
max_frames_latency (int, optional):
Maximum number of frames of latency that we allow pitch tracking to introduce into
the feature processing (affects output only if ``frames_per_chunk > 0`` and
``simulate_first_pass_online=True``)
frames_per_chunk (int, optional):
The number of frames used for energy normalization.
simulate_first_pass_online (bool, optional):
If true, the function will output features that correspond to what an online decoder
would see in the first pass of decoding -- not the final version of the features,
which is the default.
Relevant if ``frames_per_chunk > 0``.
recompute_frame (int, optional):
Only relevant for compatibility with online pitch extraction.
A non-critical parameter; the frame at which we recompute some of the forward pointers,
after revising our estimate of the signal energy.
Relevant if ``frames_per_chunk > 0``.
snip_edges (bool, optional):
If this is set to false, the incomplete frames near the ending edge won't be snipped,
so that the number of frames is the file size divided by the frame-shift.
This makes different types of features give the same number of frames.
Returns:
Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension
corresponds to pitch and NCCF.
Reference:
- A pitch extraction algorithm tuned for automatic speech recognition
P. Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S. Khudanpur
2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
Florence, 2014, pp. 2494-2498, doi: 10.1109/ICASSP.2014.6854049.
"""
shape = waveform.shape
waveform = waveform.reshape(-1, shape[-1])
result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
waveform, sample_rate, frame_length, frame_shift,
min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff,
resample_frequency, delta_pitch, nccf_ballast,
lowpass_filter_width, upsample_filter_width, max_frames_latency,
frames_per_chunk, simulate_first_pass_online, recompute_frame,
snip_edges,
)
result = result.reshape(shape[:-1] + result.shape[-2:])
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment