Commit a450cf81 authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

Kaldi MFCC (#228)

parent fd9684c8
...@@ -2,7 +2,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -2,7 +2,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import random import random
import torchaudio import torchaudio
TEST_PREFIX = ['fbank', 'spec', 'resample'] TEST_PREFIX = ['spec', 'fbank', 'mfcc', 'resample']
def generate_rand_boolean(): def generate_rand_boolean():
......
...@@ -210,6 +210,40 @@ class Test_Kaldi(unittest.TestCase): ...@@ -210,6 +210,40 @@ class Test_Kaldi(unittest.TestCase):
self._compliance_test_helper(self.test_filepath, 'fbank', 97, 22, get_output_fn, atol=1e-3, rtol=1e-1) self._compliance_test_helper(self.test_filepath, 'fbank', 97, 22, get_output_fn, atol=1e-3, rtol=1e-1)
def test_mfcc(self):
def get_output_fn(sound, args):
output = kaldi.mfcc(
sound,
blackman_coeff=args[1],
dither=0.0,
energy_floor=args[2],
frame_length=args[3],
frame_shift=args[4],
high_freq=args[5],
htk_compat=args[6],
low_freq=args[7],
num_mel_bins=args[8],
preemphasis_coefficient=args[9],
raw_energy=args[10],
remove_dc_offset=args[11],
round_to_power_of_two=args[12],
snip_edges=args[13],
subtract_mean=args[14],
use_energy=args[15],
num_ceps=args[16],
cepstral_lifter=args[17],
vtln_high=args[18],
vtln_low=args[19],
vtln_warp=args[20],
window_type=args[21])
return output
self._compliance_test_helper(self.test_filepath, 'mfcc', 145, 22, get_output_fn, atol=1e-3)
def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))
def test_resample_waveform(self): def test_resample_waveform(self):
def get_output_fn(sound, args): def get_output_fn(sound, args):
output = kaldi.resample_waveform(sound, args[1], args[2]) output = kaldi.resample_waveform(sound, args[1], args[2])
......
...@@ -3,16 +3,17 @@ import math ...@@ -3,16 +3,17 @@ import math
import fractions import fractions
import random import random
import torch import torch
import torchaudio
__all__ = [ __all__ = [
'fbank',
'get_mel_banks', 'get_mel_banks',
'inverse_mel_scale', 'inverse_mel_scale',
'inverse_mel_scale_scalar', 'inverse_mel_scale_scalar',
'mel_scale', 'mel_scale',
'mel_scale_scalar', 'mel_scale_scalar',
'spectrogram', 'spectrogram',
'fbank',
'mfcc',
'vtln_warp_freq', 'vtln_warp_freq',
'vtln_warp_mel_freq', 'vtln_warp_mel_freq',
'resample_waveform', 'resample_waveform',
...@@ -117,7 +118,9 @@ def _get_waveform_and_window_properties(waveform, channel, sample_frequency, fra ...@@ -117,7 +118,9 @@ def _get_waveform_and_window_properties(waveform, channel, sample_frequency, fra
frame_length, round_to_power_of_two, preemphasis_coefficient): frame_length, round_to_power_of_two, preemphasis_coefficient):
r"""Gets the waveform and window properties r"""Gets the waveform and window properties
""" """
waveform = waveform[max(channel, 0), :] # size (n) channel = max(channel, 0)
assert channel < waveform.size(0), ('Invalid channel %d for size %d' % (channel, waveform.size(0)))
waveform = waveform[channel, :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
...@@ -182,6 +185,15 @@ def _get_window(waveform, padded_window_size, window_size, window_shift, window_ ...@@ -182,6 +185,15 @@ def _get_window(waveform, padded_window_size, window_size, window_shift, window_
return strided_input, signal_log_energy return strided_input, signal_log_energy
def _subtract_column_mean(tensor, subtract_mean):
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
# it returns size (m, n)
if subtract_mean:
col_means = torch.mean(tensor, dim=0).unsqueeze(0)
tensor = tensor - col_means
return tensor
def spectrogram( def spectrogram(
waveform, blackman_coeff=0.42, channel=-1, dither=1.0, energy_floor=0.0, waveform, blackman_coeff=0.42, channel=-1, dither=1.0, energy_floor=0.0,
frame_length=25.0, frame_shift=10.0, min_duration=0.0, frame_length=25.0, frame_shift=10.0, min_duration=0.0,
...@@ -239,10 +251,7 @@ def spectrogram( ...@@ -239,10 +251,7 @@ def spectrogram(
power_spectrum = torch.max(fft.pow(2).sum(2), EPSILON).log() # size (m, padded_window_size // 2 + 1) power_spectrum = torch.max(fft.pow(2).sum(2), EPSILON).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy power_spectrum[:, 0] = signal_log_energy
if subtract_mean: power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
col_means = torch.mean(power_spectrum, dim=0).unsqueeze(0) # size (1, padded_window_size // 2 + 1)
power_spectrum = power_spectrum - col_means
return power_spectrum return power_spectrum
...@@ -504,7 +513,7 @@ def fbank( ...@@ -504,7 +513,7 @@ def fbank(
# avoid log of zero (which should be prevented anyway by dithering) # avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, EPSILON).log() mel_energies = torch.max(mel_energies, EPSILON).log()
# if use_energy then add it as the first column for htk_compat == true else last column # if use_energy then add it as the last column for htk_compat == true else first column
if use_energy: if use_energy:
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
# returns size (m, num_mel_bins + 1) # returns size (m, num_mel_bins + 1)
...@@ -513,13 +522,134 @@ def fbank( ...@@ -513,13 +522,134 @@ def fbank(
else: else:
mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1) mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
if subtract_mean: mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
col_means = torch.mean(mel_energies, dim=0).unsqueeze(0) # size (1, num_mel_bins + use_energy)
mel_energies = mel_energies - col_means
return mel_energies return mel_energies
def _get_dct_matrix(num_ceps, num_mel_bins):
# returns a dct matrix of size (num_mel_bins, num_ceps)
# size (num_mel_bins, num_mel_bins)
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, 'ortho')
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
# this would be the first column in the dct_matrix for torchaudio as it expects a
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
# expects a left multiply e.g. dct_matrix * vector).
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
dct_matrix = dct_matrix[:, :num_ceps]
return dct_matrix
def _get_lifter_coeffs(num_ceps, cepstral_lifter):
# returns size (num_ceps)
# Compute liftering coefficients (scaling on cepstral coeffs)
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
i = torch.arange(num_ceps, dtype=torch.get_default_dtype())
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
def mfcc(
waveform, blackman_coeff=0.42, cepstral_lifter=22.0, channel=-1, dither=1.0,
energy_floor=0.0, frame_length=25.0, frame_shift=10.0, high_freq=0.0, htk_compat=False,
low_freq=20.0, num_ceps=13, min_duration=0.0, num_mel_bins=23, preemphasis_coefficient=0.97,
raw_energy=True, remove_dc_offset=True, round_to_power_of_two=True,
sample_frequency=16000.0, snip_edges=True, subtract_mean=False, use_energy=False,
vtln_high=-500.0, vtln_low=100.0, vtln_warp=1.0, window_type=POVEY):
r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
compute-mfcc-feats.
Args:
waveform (torch.Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
cepstral_lifter (float): Constant that controls scaling of MFCCs (Default: ``22.0``)
channel (int): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``1.0``)
energy_floor (float): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``0.0``)
frame_length (float): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float): Frame shift in milliseconds (Default: ``10.0``)
high_freq (float): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (Default: ``0.0``)
htk_compat (bool): If true, put energy last. Warning: not sufficient to get HTK compatible features (need
to change other parameters). (Default: ``False``)
low_freq (float): Low cutoff frequency for mel bins (Default: ``20.0``)
num_ceps (int): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
min_duration (float): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
num_mel_bins (int): Number of triangular mel-frequency bins (Default: ``23``)
preemphasis_coefficient (float): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset: Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
use_energy (bool): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
vtln_high (float): High inflection point in piecewise linear VTLN warping function (if
negative, offset from high-mel-freq (Default: ``-500.0``)
vtln_low (float): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
vtln_warp (float): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
window_type (str): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: ``'povey'``)
Returns:
torch.Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
where m is calculated in _get_strided
"""
assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (num_ceps, num_mel_bins)
# The mel_energies should not be squared (use_power=True), not have mean subtracted
# (subtract_mean=False), and use log (use_log_fbank=True).
# size (m, num_mel_bins + use_energy)
feature = fbank(waveform=waveform, blackman_coeff=blackman_coeff, channel=channel,
dither=dither, energy_floor=energy_floor, frame_length=frame_length,
frame_shift=frame_shift, high_freq=high_freq, htk_compat=htk_compat,
low_freq=low_freq, min_duration=min_duration, num_mel_bins=num_mel_bins,
preemphasis_coefficient=preemphasis_coefficient, raw_energy=raw_energy,
remove_dc_offset=remove_dc_offset, round_to_power_of_two=round_to_power_of_two,
sample_frequency=sample_frequency, snip_edges=snip_edges, subtract_mean=False,
use_energy=use_energy, use_log_fbank=True, use_power=True,
vtln_high=vtln_high, vtln_low=vtln_low, vtln_warp=vtln_warp, window_type=window_type)
if use_energy:
# size (m)
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
# offset is 0 if htk_compat==True else 1
mel_offset = int(not htk_compat)
feature = feature[:, mel_offset:(num_mel_bins + mel_offset)]
# size (num_mel_bins, num_ceps)
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins)
# size (m, num_ceps)
feature = feature.matmul(dct_matrix)
if cepstral_lifter != 0.0:
# size (1, num_ceps)
lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
feature *= lifter_coeffs
# if use_energy then replace the last column for htk_compat == true else first column
if use_energy:
feature[:, 0] = signal_log_energy
if htk_compat:
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
feature = feature[:, 1:] # size (m, num_ceps - 1)
if not use_energy:
# scale on C0 (actually removing a scale we previously added that's
# part of one common definition of the cosine transform.)
energy *= math.sqrt(2)
feature = torch.cat((feature, energy), dim=1)
feature = _subtract_column_mean(feature, subtract_mean)
return feature
def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, window_width, def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, window_width,
lowpass_cutoff, lowpass_filter_width): lowpass_cutoff, lowpass_filter_width):
r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment