You need to sign in or sign up before continuing.
Unverified commit c92392fc, authored by lawlict and committed by GitHub
Browse files

Improve the speed of kaldi.fbank with fused operator (#947)


Co-authored-by: linqj3 <linqj3@lenovo.com>
parent 2d879132
...@@ -573,7 +573,7 @@ def fbank(waveform: Tensor, ...@@ -573,7 +573,7 @@ def fbank(waveform: Tensor,
# size (m, padded_window_size // 2 + 1, 2) # size (m, padded_window_size // 2 + 1, 2)
fft = torch.rfft(strided_input, 1, normalized=False, onesided=True) fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
power_spectrum = fft.pow(2).sum(2).unsqueeze(1) # size (m, 1, padded_window_size // 2 + 1) power_spectrum = fft.pow(2).sum(2) # size (m, padded_window_size // 2 + 1)
if not use_power: if not use_power:
power_spectrum = power_spectrum.pow(0.5) power_spectrum = power_spectrum.pow(0.5)
...@@ -582,11 +582,11 @@ def fbank(waveform: Tensor, ...@@ -582,11 +582,11 @@ def fbank(waveform: Tensor,
low_freq, high_freq, vtln_low, vtln_high, vtln_warp) low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.to(device=device, dtype=dtype) mel_energies = mel_energies.to(device=device, dtype=dtype)
# pad right column with zeros and add dimension, size (1, num_mel_bins, padded_window_size // 2 + 1) # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0).unsqueeze(0) mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0)
# sum with mel filterbanks over the power spectrum, size (m, num_mel_bins) # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies = (power_spectrum * mel_energies).sum(dim=2) mel_energies = torch.mm(power_spectrum, mel_energies.T)
if use_log_fbank: if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering) # avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment