Unverified commit ddb8577d, authored by Bhargav Kathivarapu, committed by GitHub
Browse files

Migrate kaldi spectrogram (#687)



* Migrate spectrogram

* Update spectrogram in kaldi.py to support device and dtype

* Remove failing tests
Signed-off-by: Bhargav Kathivarapu <bhargavkathivarapu31@gmail.com>
parent b56a27b5
This diff is collapsed.
...@@ -86,6 +86,17 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -86,6 +86,17 @@ class Kaldi(common_utils.TestBaseMixin):
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json')))
@unittest.skipIf(_not_available('compute-spectrogram-feats'), '`compute-spectrogram-feats` not available')
def test_spectrogram(self, kwargs):
    """spectrogram should be numerically compatible with compute-spectrogram-feats

    Runs once per parameter set loaded from ``kaldi_test_spectrogram_args.json``
    and is skipped when ``_not_available('compute-spectrogram-feats')`` is true.

    Args:
        kwargs: Keyword arguments for one test case. They are passed to
            ``torchaudio.compliance.kaldi.spectrogram`` and, through
            ``_convert_args``, appended to the Kaldi command line.
    """
    audio_path = common_utils.get_asset_path('kaldi_file.wav')

    # Only the first element returned by ``load_wav`` is used; it is moved to
    # the dtype/device configured on this test class before being processed.
    loaded = torchaudio.load_wav(audio_path)
    samples = loaded[0].to(dtype=self.dtype, device=self.device)
    actual = torchaudio.compliance.kaldi.spectrogram(samples, **kwargs)

    # Reference output: same wav file and same options, fed to the Kaldi CLI.
    cli_flags = _convert_args(**kwargs)
    kaldi_cmd = ['compute-spectrogram-feats'] + cli_flags + ['scp:-', 'ark:-']
    reference = _run_kaldi(kaldi_cmd, 'scp', audio_path)

    self.assert_equal(actual, expected=reference, rtol=1e-4, atol=1e-8)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json'))) @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json')))
@unittest.skipIf(_not_available('compute-mfcc-feats'), '`compute-mfcc-feats` not available') @unittest.skipIf(_not_available('compute-mfcc-feats'), '`compute-mfcc-feats` not available')
def test_mfcc(self, kwargs): def test_mfcc(self, kwargs):
......
...@@ -272,6 +272,9 @@ def spectrogram(waveform: Tensor, ...@@ -272,6 +272,9 @@ def spectrogram(waveform: Tensor,
Tensor: A spectrogram identical to what Kaldi would output. The shape is Tensor: A spectrogram identical to what Kaldi would output. The shape is
(m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
""" """
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient) waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
...@@ -287,7 +290,7 @@ def spectrogram(waveform: Tensor, ...@@ -287,7 +290,7 @@ def spectrogram(waveform: Tensor,
fft = torch.rfft(strided_input, 1, normalized=False, onesided=True) fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
# Convert the FFT into a power spectrum # Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.pow(2).sum(2), EPSILON).log() # size (m, padded_window_size // 2 + 1) power_spectrum = torch.max(fft.pow(2).sum(2), epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy power_spectrum[:, 0] = signal_log_energy
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment