Unverified commit e43ee196, authored by moto, committed by GitHub
Browse files

Replace torchaudio.load in test with scipy func (#762)

parent 4b583eab
...@@ -299,8 +299,6 @@ class TestIstft(common_utils.TorchaudioTestCase): ...@@ -299,8 +299,6 @@ class TestIstft(common_utils.TorchaudioTestCase):
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase): class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
backend = 'default'
def test_pitch(self): def test_pitch(self):
test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav") test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav") test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
...@@ -312,7 +310,7 @@ class TestDetectPitchFrequency(common_utils.TorchaudioTestCase): ...@@ -312,7 +310,7 @@ class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
] ]
for filename, freq_ref in tests: for filename, freq_ref in tests:
waveform, sample_rate = torchaudio.load(filename) waveform, sample_rate = common_utils.load_wav(filename)
freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
......
...@@ -5,11 +5,16 @@ import kaldi_io ...@@ -5,11 +5,16 @@ import kaldi_io
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
import torchaudio.compliance.kaldi import torchaudio.compliance.kaldi
from . import common_utils
from .common_utils import load_params
from parameterized import parameterized from parameterized import parameterized
from .common_utils import (
TestBaseMixin,
load_params,
skipIfNoExec,
get_asset_path,
load_wav
)
def _convert_args(**kwargs): def _convert_args(**kwargs):
args = [] args = []
...@@ -43,14 +48,12 @@ def _run_kaldi(command, input_type, input_value): ...@@ -43,14 +48,12 @@ def _run_kaldi(command, input_type, input_value):
return torch.from_numpy(result.copy()) # copy suppresses some torch warning return torch.from_numpy(result.copy()) # copy suppresses some torch warning
class Kaldi(common_utils.TestBaseMixin): class Kaldi(TestBaseMixin):
backend = 'sox'
def assert_equal(self, output, *, expected, rtol=None, atol=None): def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol) self.assertEqual(output, expected, rtol=rtol, atol=atol)
@common_utils.skipIfNoExec('apply-cmvn-sliding') @skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self): def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = { kwargs = {
...@@ -67,33 +70,33 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -67,33 +70,33 @@ class Kaldi(common_utils.TestBaseMixin):
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
@parameterized.expand(load_params('kaldi_test_fbank_args.json')) @parameterized.expand(load_params('kaldi_test_fbank_args.json'))
@common_utils.skipIfNoExec('compute-fbank-feats') @skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs): def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats""" """fbank should be numerically compatible with compute-fbank-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav') wave_file = get_asset_path('kaldi_file.wav')
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params('kaldi_test_spectrogram_args.json')) @parameterized.expand(load_params('kaldi_test_spectrogram_args.json'))
@common_utils.skipIfNoExec('compute-spectrogram-feats') @skipIfNoExec('compute-spectrogram-feats')
def test_spectrogram(self, kwargs): def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats""" """spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav') wave_file = get_asset_path('kaldi_file.wav')
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs) result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs)
command = ['compute-spectrogram-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ['compute-spectrogram-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params('kaldi_test_mfcc_args.json')) @parameterized.expand(load_params('kaldi_test_mfcc_args.json'))
@common_utils.skipIfNoExec('compute-mfcc-feats') @skipIfNoExec('compute-mfcc-feats')
def test_mfcc(self, kwargs): def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats""" """mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav') wave_file = get_asset_path('kaldi_file.wav')
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs) result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)
command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = _run_kaldi(command, 'scp', wave_file)
......
...@@ -160,7 +160,8 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -160,7 +160,8 @@ class TestTransforms(common_utils.TorchaudioTestCase):
"""Test suite for functions in `transforms` module.""" """Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
common_utils.set_audio_backend('default') common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset('sinewave.wav') path = common_utils.get_asset_path('sinewave.wav')
sound, sample_rate = common_utils.load_wav(path)
sound_librosa = sound.cpu().numpy().squeeze() # (64000) sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram # test core spectrogram
...@@ -300,9 +301,9 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -300,9 +301,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
hop_length = n_fft // 4 hop_length = n_fft // 4
# Prepare mel spectrogram input. We use torchaudio to compute one. # Prepare mel spectrogram input. We use torchaudio to compute one.
common_utils.set_audio_backend('default') path = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
sound, sample_rate = _load_audio_asset( sound, sample_rate = common_utils.load_wav(path)
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14) sound = sound[:, 2**10:2**10 + 2**14]
sound = sound.mean(dim=0, keepdim=True) sound = sound.mean(dim=0, keepdim=True)
spec_orig = F.spectrogram( spec_orig = F.spectrogram(
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft, sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
......
...@@ -45,7 +45,7 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -45,7 +45,7 @@ class Tester(common_utils.TorchaudioTestCase):
def test_AmplitudeToDB(self): def test_AmplitudeToDB(self):
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, sample_rate = torchaudio.load(filepath) waveform = common_utils.load_wav(filepath)[0]
mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.) mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
power_to_db_transform = transforms.AmplitudeToDB('power', 80.) power_to_db_transform = transforms.AmplitudeToDB('power', 80.)
...@@ -115,7 +115,7 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -115,7 +115,7 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all()) self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
# check on multi-channel audio # check on multi-channel audio
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
x_stereo, sr_stereo = torchaudio.load(filepath) # (2, 278756), 44100 x_stereo = common_utils.load_wav(filepath)[0] # (2, 278756), 44100
spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394) spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394)
self.assertTrue(spectrogram_stereo.dim() == 3) self.assertTrue(spectrogram_stereo.dim() == 3)
self.assertTrue(spectrogram_stereo.size(0) == 2) self.assertTrue(spectrogram_stereo.size(0) == 2)
...@@ -166,7 +166,7 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -166,7 +166,7 @@ class Tester(common_utils.TorchaudioTestCase):
def test_resample_size(self): def test_resample_size(self):
input_path = common_utils.get_asset_path('sinewave.wav') input_path = common_utils.get_asset_path('sinewave.wav')
waveform, sample_rate = torchaudio.load(input_path) waveform, sample_rate = common_utils.load_wav(input_path)
upsample_rate = sample_rate * 2 upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2 downsample_rate = sample_rate // 2
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import unittest import unittest
import torch import torch
import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
import torchaudio.transforms as T import torchaudio.transforms as T
...@@ -616,6 +615,5 @@ class Transforms(common_utils.TestBaseMixin): ...@@ -616,6 +615,5 @@ class Transforms(common_utils.TestBaseMixin):
def test_Vad(self): def test_Vad(self):
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav") filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
common_utils.set_audio_backend('default') waveform, sample_rate = common_utils.load_wav(filepath)
waveform, sample_rate = torchaudio.load(filepath)
self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform) self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment