Unverified Commit 3a4f3569 authored by moto's avatar moto Committed by GitHub
Browse files

Make Kaldi fbank support cuda (#619)

* Make fbank support cuda

* Reduce rtol for kaldi

* fix test

* fix flake8
parent 00d38203
......@@ -37,7 +37,7 @@ def _run_kaldi(command, input_type, input_value):
key = 'foo'
process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
if input_type == 'ark':
kaldi_io.write_mat(process.stdin, input_value.numpy(), key=key)
kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
elif input_type == 'scp':
process.stdin.write(f'{key} {input_value}'.encode('utf8'))
else:
......@@ -47,7 +47,7 @@ def _run_kaldi(command, input_type, input_value):
return torch.from_numpy(result.copy()) # copy supresses some torch warning
class TestFunctional:
class Kaldi(common_utils.TestBaseMixin):
@unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
......@@ -58,11 +58,11 @@ class TestFunctional:
'norm_vars': False,
}
tensor = torch.randn(40, 10)
tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'ark', tensor)
torch.testing.assert_allclose(result, kaldi_result)
torch.testing.assert_allclose(result.cpu(), kaldi_result.to(self.dtype))
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
def test_fbank(self):
......@@ -93,7 +93,11 @@ class TestFunctional:
}
wave_file = common_utils.get_asset_path('kaldi_file.wav')
result = torchaudio.compliance.kaldi.fbank(torchaudio.load_wav(wave_file)[0], **kwargs)
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file)
torch.testing.assert_allclose(result, kaldi_result)
torch.testing.assert_allclose(result.cpu(), kaldi_result.to(dtype=self.dtype), rtol=1e-4, atol=1e-8)
common_utils.define_test_suites(globals(), [Kaldi])
......@@ -33,6 +33,10 @@ BLACKMAN = 'blackman'
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
def _get_epsilon(device, dtype):
return EPSILON.to(device=device, dtype=dtype)
def _next_power_of_2(x: int) -> int:
r"""Returns the smallest power of 2 that is greater than x
"""
......@@ -60,7 +64,7 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg
if snip_edges:
if num_samples < window_size:
return torch.empty((0, 0))
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
else:
m = 1 + (num_samples - window_size) // window_shift
else:
......@@ -83,24 +87,27 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg
def _feature_window_function(window_type: str,
window_size: int,
blackman_coeff: float) -> Tensor:
blackman_coeff: float,
device: torch.device,
dtype: int,
) -> Tensor:
r"""Returns a window function with the given type and size
"""
if window_type == HANNING:
return torch.hann_window(window_size, periodic=False)
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
elif window_type == HAMMING:
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46)
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
elif window_type == POVEY:
# like hanning but goes to zero at edges
return torch.hann_window(window_size, periodic=False).pow(0.85)
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
elif window_type == RECTANGULAR:
return torch.ones(window_size)
return torch.ones(window_size, device=device, dtype=dtype)
elif window_type == BLACKMAN:
a = 2 * math.pi / (window_size - 1)
window_function = torch.arange(window_size)
window_function = torch.arange(window_size, device=device, dtype=dtype)
# can't use torch.blackman_window as they use different coefficients
return (blackman_coeff - 0.5 * torch.cos(a * window_function) +
(0.5 - blackman_coeff) * torch.cos(2 * a * window_function))
(0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype)
else:
raise Exception('Invalid window type ' + window_type)
......@@ -110,12 +117,12 @@ def _get_log_energy(strided_input: Tensor,
energy_floor: float) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)
"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
if energy_floor == 0.0:
return log_energy
else:
return torch.max(log_energy,
torch.tensor(math.log(energy_floor)))
return torch.max(
log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
def _get_waveform_and_window_properties(waveform: Tensor,
......@@ -160,12 +167,15 @@ def _get_window(waveform: Tensor,
Returns:
(Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
# size (m, window_size)
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
if dither != 0.0:
# Returns a random number strictly between 0 and 1
x = torch.max(EPSILON, torch.rand(strided_input.shape))
x = torch.max(epsilon, torch.rand(strided_input.shape, device=device, dtype=dtype))
rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x)
strided_input = strided_input + rand_gauss * dither
......@@ -177,7 +187,7 @@ def _get_window(waveform: Tensor,
if raw_energy:
# Compute the log energy of each row/frame before applying preemphasis and
# window function
signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor) # size (m)
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
......@@ -187,7 +197,7 @@ def _get_window(waveform: Tensor,
# Apply window_function to each row/frame
window_function = _feature_window_function(
window_type, window_size, blackman_coeff).unsqueeze(0) # size (1, window_size)
window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)
# Pad columns with zero until we reach size (m, padded_window_size)
......@@ -198,7 +208,7 @@ def _get_window(waveform: Tensor,
# Compute energy after window function (not the raw one)
if not raw_energy:
signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor) # size (m)
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
return strided_input, signal_log_energy
......@@ -541,12 +551,14 @@ def fbank(waveform: Tensor,
Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0)
return torch.empty(0, device=device, dtype=dtype)
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
strided_input, signal_log_energy = _get_window(
......@@ -563,6 +575,7 @@ def fbank(waveform: Tensor,
# size (num_mel_bins, padded_window_size // 2)
mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.to(device=device, dtype=dtype)
# pad right column with zeros and add dimension, size (1, num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0).unsqueeze(0)
......@@ -571,7 +584,7 @@ def fbank(waveform: Tensor,
mel_energies = (power_spectrum * mel_energies).sum(dim=2)
if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, EPSILON).log()
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
# if use_energy then add it as the last column for htk_compat == true else first column
if use_energy:
......@@ -737,7 +750,9 @@ def _get_LR_indices_and_weights(orig_freq: float,
output_samples_in_unit: int,
window_width: float,
lowpass_cutoff: float,
lowpass_filter_width: int) -> Tuple[Tensor, Tensor]:
lowpass_filter_width: int,
device: torch.device,
dtype: int) -> Tuple[Tensor, Tensor]:
r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
resampling as well as the indices in which they are valid. LinearResample (LR) means
that the output signal is at linearly spaced intervals (i.e the output signal has a
......@@ -785,7 +800,7 @@ def _get_LR_indices_and_weights(orig_freq: float,
which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
"""
assert lowpass_cutoff < min(orig_freq, new_freq) / 2
output_t = torch.arange(0., output_samples_in_unit) / new_freq
output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq
min_t = output_t - window_width
max_t = output_t + window_width
......@@ -795,7 +810,7 @@ def _get_LR_indices_and_weights(orig_freq: float,
max_weight_width = num_indices.max()
# create a group of weights of size (output_samples_in_unit, max_weight_width)
j = torch.arange(max_weight_width).unsqueeze(0)
j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0)
input_index = min_input_index.unsqueeze(1) + j
delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)
......@@ -905,9 +920,9 @@ def resample_waveform(waveform: Tensor,
output_samples_in_unit = int(new_freq) // base_freq
window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit,
window_width, lowpass_cutoff, lowpass_filter_width)
weights = weights.to(device=device, dtype=dtype) # TODO Create weights on device directly
first_indices, weights = _get_LR_indices_and_weights(
orig_freq, new_freq, output_samples_in_unit,
window_width, lowpass_cutoff, lowpass_filter_width, device, dtype)
assert first_indices.dim() == 1
# TODO figure a better way to do this. conv1d reaches every element i*stride + padding
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment