Unverified commit c92392fc, authored by lawlict, committed by GitHub
Browse files

Improve the speed of kaldi.fbank with fused operator (#947)


Co-authored-by: linqj3 <linqj3@lenovo.com>
parent 2d879132
......@@ -573,7 +573,7 @@ def fbank(waveform: Tensor,
# size (m, padded_window_size // 2 + 1, 2)
fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
power_spectrum = fft.pow(2).sum(2).unsqueeze(1) # size (m, 1, padded_window_size // 2 + 1)
power_spectrum = fft.pow(2).sum(2) # size (m, padded_window_size // 2 + 1)
if not use_power:
power_spectrum = power_spectrum.pow(0.5)
......@@ -582,11 +582,11 @@ def fbank(waveform: Tensor,
low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.to(device=device, dtype=dtype)
# pad right column with zeros and add dimension, size (1, num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0).unsqueeze(0)
# pad right column with zeros, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0)
# sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies = (power_spectrum * mel_energies).sum(dim=2)
mel_energies = torch.mm(power_spectrum, mel_energies.T)
if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment