"docs/vscode:/vscode.git/clone" did not exist on "37665a0b3fd66661cf8d89e7d2cb99abd10a0f64"
Unverified commit 3a4f3569, authored by moto, committed by GitHub

Make Kaldi fbank support cuda (#619)

* Make fbank support cuda

* Reduce rtol for kaldi

* fix test

* fix flake8
parent 00d38203
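Before the diff, here is a minimal usage sketch of what this change enables: running the Kaldi-compatible fbank directly on a CUDA tensor. The file path and the explicit num_mel_bins are placeholders; load_wav mirrors the helper used in the updated test below.

import torch
import torchaudio
import torchaudio.compliance.kaldi

# Load a waveform on CPU, then move it to the GPU; with this commit the
# fbank output follows the input's device and dtype.
waveform = torchaudio.load_wav('sample.wav')[0]  # placeholder path
if torch.cuda.is_available():
    waveform = waveform.to(device='cuda', dtype=torch.float64)
fbank_feats = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=23)
print(fbank_feats.shape, fbank_feats.device, fbank_feats.dtype)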
@@ -37,7 +37,7 @@ def _run_kaldi(command, input_type, input_value):
     key = 'foo'
     process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
     if input_type == 'ark':
-        kaldi_io.write_mat(process.stdin, input_value.numpy(), key=key)
+        kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
     elif input_type == 'scp':
         process.stdin.write(f'{key} {input_value}'.encode('utf8'))
     else:
@@ -47,7 +47,7 @@ def _run_kaldi(command, input_type, input_value):
     return torch.from_numpy(result.copy())  # copy suppresses some torch warning


-class TestFunctional:
+class Kaldi(common_utils.TestBaseMixin):
     @unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
     def test_sliding_window_cmn(self):
         """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
@@ -58,11 +58,11 @@ class TestFunctional:
             'norm_vars': False,
         }

-        tensor = torch.randn(40, 10)
+        tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
         result = F.sliding_window_cmn(tensor, **kwargs)
         command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-']
         kaldi_result = _run_kaldi(command, 'ark', tensor)
-        torch.testing.assert_allclose(result, kaldi_result)
+        torch.testing.assert_allclose(result.cpu(), kaldi_result.to(self.dtype))

     @unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
     def test_fbank(self):
@@ -93,7 +93,11 @@ class TestFunctional:
         }
         wave_file = common_utils.get_asset_path('kaldi_file.wav')
-        result = torchaudio.compliance.kaldi.fbank(torchaudio.load_wav(wave_file)[0], **kwargs)
+        waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
+        result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
         command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
         kaldi_result = _run_kaldi(command, 'scp', wave_file)
-        torch.testing.assert_allclose(result, kaldi_result)
+        torch.testing.assert_allclose(result.cpu(), kaldi_result.to(dtype=self.dtype), rtol=1e-4, atol=1e-8)
+
+
+common_utils.define_test_suites(globals(), [Kaldi])
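The test refactor above swaps the plain TestFunctional class for a Kaldi mixin so the same test bodies can run across dtype/device combinations via self.dtype and self.device. Below is a hypothetical sketch of what common_utils.TestBaseMixin and define_test_suites might do; the real helpers may differ. The hunks after this point apparently belong to torchaudio/compliance/kaldi.py.

import itertools
import unittest

import torch

class TestBaseMixin:
    # Filled in per generated suite; tests read self.dtype / self.device.
    dtype = None
    device = None

def define_test_suites(scope, mixins):
    # Generate one concrete TestCase per (dtype, device) pair.
    for mixin in mixins:
        for dtype, device in itertools.product(
                (torch.float32, torch.float64), ('cpu', 'cuda')):
            name = f'Test{mixin.__name__}_{str(dtype).split(".")[-1]}_{device}'
            suite = type(name, (mixin, unittest.TestCase),
                         {'dtype': dtype, 'device': torch.device(device)})
            if device == 'cuda':
                # Skip GPU variants on machines without CUDA.
                suite = unittest.skipIf(
                    not torch.cuda.is_available(), 'CUDA not available')(suite)
            scope[name] = suite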
@@ -33,6 +33,10 @@ BLACKMAN = 'blackman'
 WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]


+def _get_epsilon(device, dtype):
+    return EPSILON.to(device=device, dtype=dtype)
+
+
 def _next_power_of_2(x: int) -> int:
     r"""Returns the smallest power of 2 that is greater than x
     """
@@ -60,7 +64,7 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg
     if snip_edges:
         if num_samples < window_size:
-            return torch.empty((0, 0))
+            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
         else:
             m = 1 + (num_samples - window_size) // window_shift
     else:
@@ -83,24 +87,27 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg
 def _feature_window_function(window_type: str,
                              window_size: int,
-                             blackman_coeff: float) -> Tensor:
+                             blackman_coeff: float,
+                             device: torch.device,
+                             dtype: int,
+                             ) -> Tensor:
     r"""Returns a window function with the given type and size
     """
     if window_type == HANNING:
-        return torch.hann_window(window_size, periodic=False)
+        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
     elif window_type == HAMMING:
-        return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46)
+        return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
     elif window_type == POVEY:
         # like hanning but goes to zero at edges
-        return torch.hann_window(window_size, periodic=False).pow(0.85)
+        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
     elif window_type == RECTANGULAR:
-        return torch.ones(window_size)
+        return torch.ones(window_size, device=device, dtype=dtype)
     elif window_type == BLACKMAN:
         a = 2 * math.pi / (window_size - 1)
-        window_function = torch.arange(window_size)
+        window_function = torch.arange(window_size, device=device, dtype=dtype)
         # can't use torch.blackman_window as they use different coefficients
         return (blackman_coeff - 0.5 * torch.cos(a * window_function) +
-                (0.5 - blackman_coeff) * torch.cos(2 * a * window_function))
+                (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype)
     else:
         raise Exception('Invalid window type ' + window_type)
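Passing device and dtype into the window constructors means each window is allocated directly where the frames live, instead of being built on CPU and copied over on every call. An illustrative snippet (assumes a CUDA device is present):

import torch

device, dtype = torch.device('cuda'), torch.float32
# The POVEY branch above: a Hann window raised to the 0.85 power, created on-device.
window = torch.hann_window(400, periodic=False, device=device, dtype=dtype).pow(0.85)
assert window.device.type == 'cuda' and window.dtype == torch.float32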
@@ -110,12 +117,12 @@ def _get_log_energy(strided_input: Tensor,
                     energy_floor: float) -> Tensor:
     r"""Returns the log energy of size (m) for a strided_input (m,*)
     """
+    device, dtype = strided_input.device, strided_input.dtype
     log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
     if energy_floor == 0.0:
         return log_energy
-    else:
-        return torch.max(log_energy,
-                         torch.tensor(math.log(energy_floor)))
+    return torch.max(
+        log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))


 def _get_waveform_and_window_properties(waveform: Tensor,
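The log-energy path clamps the per-frame energy with epsilon before taking the log, then applies an optional floor of log(energy_floor); the change above simply materializes that floor on the input's device and dtype. A self-contained re-creation of the same arithmetic (the frame matrix and floor value are made up):

import math
import torch

frames = torch.randn(5, 400, dtype=torch.float64)  # (m, window_size), made-up data
eps = torch.tensor(torch.finfo(torch.float).eps, dtype=frames.dtype)
log_energy = torch.max(frames.pow(2).sum(1), eps).log()  # size (m,)
energy_floor = 1e-10  # made-up floor
log_energy = torch.max(
    log_energy, torch.tensor(math.log(energy_floor), dtype=frames.dtype))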
@@ -160,12 +167,15 @@ def _get_window(waveform: Tensor,
     Returns:
         (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
     """
+    device, dtype = waveform.device, waveform.dtype
+    epsilon = _get_epsilon(device, dtype)
+
     # size (m, window_size)
     strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)

     if dither != 0.0:
         # Returns a random number strictly between 0 and 1
-        x = torch.max(EPSILON, torch.rand(strided_input.shape))
+        x = torch.max(epsilon, torch.rand(strided_input.shape, device=device, dtype=dtype))
         rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x)
         strided_input = strided_input + rand_gauss * dither
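The dither noise above is a Box-Muller-style transform: a uniform draw x, clamped with epsilon so x.log() stays finite, is mapped through sqrt(-2 ln x) * cos(2 pi x). Note that it reuses the single uniform in both factors rather than drawing two as textbook Box-Muller does. A standalone re-creation with made-up values:

import math
import torch

frames = torch.randn(5, 400)  # made-up frame matrix
dither = 1.0                  # made-up dither scale
eps = torch.tensor(torch.finfo(torch.float).eps)
x = torch.max(eps, torch.rand(frames.shape))  # uniform draw, clamped away from 0
rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x)
frames = frames + rand_gauss * dither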
@@ -177,7 +187,7 @@ def _get_window(waveform: Tensor,
     if raw_energy:
         # Compute the log energy of each row/frame before applying preemphasis and
         # window function
-        signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor)  # size (m)
+        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

     if preemphasis_coefficient != 0.0:
         # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
@@ -187,7 +197,7 @@ def _get_window(waveform: Tensor,
     # Apply window_function to each row/frame
     window_function = _feature_window_function(
-        window_type, window_size, blackman_coeff).unsqueeze(0)  # size (1, window_size)
+        window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0)  # size (1, window_size)
     strided_input = strided_input * window_function  # size (m, window_size)

     # Pad columns with zero until we reach size (m, padded_window_size)
@@ -198,7 +208,7 @@ def _get_window(waveform: Tensor,
     # Compute energy after window function (not the raw one)
     if not raw_energy:
-        signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor)  # size (m)
+        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

     return strided_input, signal_log_energy
@@ -541,12 +551,14 @@ def fbank(waveform: Tensor,
         Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
         where m is calculated in _get_strided
     """
+    device, dtype = waveform.device, waveform.dtype
+
     waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
         waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)

     if len(waveform) < min_duration * sample_frequency:
         # signal is too short
-        return torch.empty(0)
+        return torch.empty(0, device=device, dtype=dtype)

     # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
     strided_input, signal_log_energy = _get_window(
@@ -563,6 +575,7 @@ def fbank(waveform: Tensor,
     # size (num_mel_bins, padded_window_size // 2)
     mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
                                     low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
+    mel_energies = mel_energies.to(device=device, dtype=dtype)

     # pad right column with zeros and add dimension, size (1, num_mel_bins, padded_window_size // 2 + 1)
     mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0).unsqueeze(0)
@@ -571,7 +584,7 @@ def fbank(waveform: Tensor,
     mel_energies = (power_spectrum * mel_energies).sum(dim=2)
     if use_log_fbank:
         # avoid log of zero (which should be prevented anyway by dithering)
-        mel_energies = torch.max(mel_energies, EPSILON).log()
+        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

     # if use_energy then add it as the last column for htk_compat == true else first column
     if use_energy:
@@ -737,7 +750,9 @@ def _get_LR_indices_and_weights(orig_freq: float,
                                 output_samples_in_unit: int,
                                 window_width: float,
                                 lowpass_cutoff: float,
-                                lowpass_filter_width: int) -> Tuple[Tensor, Tensor]:
+                                lowpass_filter_width: int,
+                                device: torch.device,
+                                dtype: int) -> Tuple[Tensor, Tensor]:
     r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
     resampling as well as the indices in which they are valid. LinearResample (LR) means
     that the output signal is at linearly spaced intervals (i.e. the output signal has a
@@ -785,7 +800,7 @@ def _get_LR_indices_and_weights(orig_freq: float,
         which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
     """
     assert lowpass_cutoff < min(orig_freq, new_freq) / 2
-    output_t = torch.arange(0., output_samples_in_unit) / new_freq
+    output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq
     min_t = output_t - window_width
     max_t = output_t + window_width
@@ -795,7 +810,7 @@ def _get_LR_indices_and_weights(orig_freq: float,
     max_weight_width = num_indices.max()
     # create a group of weights of size (output_samples_in_unit, max_weight_width)
-    j = torch.arange(max_weight_width).unsqueeze(0)
+    j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0)
     input_index = min_input_index.unsqueeze(1) + j
     delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)
@@ -905,9 +920,9 @@ def resample_waveform(waveform: Tensor,
     output_samples_in_unit = int(new_freq) // base_freq
     window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
-    first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit,
-                                                         window_width, lowpass_cutoff, lowpass_filter_width)
-    weights = weights.to(device=device, dtype=dtype)  # TODO Create weights on device directly
+    first_indices, weights = _get_LR_indices_and_weights(
+        orig_freq, new_freq, output_samples_in_unit,
+        window_width, lowpass_cutoff, lowpass_filter_width, device, dtype)

     assert first_indices.dim() == 1
     # TODO figure a better way to do this. conv1d reaches every element i*stride + padding
...
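With device and dtype threaded into _get_LR_indices_and_weights, the resampling weights are built on the target device from the start, which is what lets the old weights-to-device copy (and its TODO) be deleted. A hypothetical end-to-end usage, with a made-up input signal:

import torch
import torchaudio.compliance.kaldi

waveform = torch.randn(1, 16000)  # made-up one-second mono signal
if torch.cuda.is_available():
    waveform = waveform.cuda()
resampled = torchaudio.compliance.kaldi.resample_waveform(
    waveform, orig_freq=16000., new_freq=8000.)
print(resampled.shape, resampled.device)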