Unverified Commit e43ee196 authored by moto's avatar moto Committed by GitHub
Browse files

Replace torchaudio.load in test with scipy func (#762)

parent 4b583eab
......@@ -299,8 +299,6 @@ class TestIstft(common_utils.TorchaudioTestCase):
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
backend = 'default'
def test_pitch(self):
test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
......@@ -312,7 +310,7 @@ class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
]
for filename, freq_ref in tests:
waveform, sample_rate = torchaudio.load(filename)
waveform, sample_rate = common_utils.load_wav(filename)
freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
......
......@@ -5,11 +5,16 @@ import kaldi_io
import torch
import torchaudio.functional as F
import torchaudio.compliance.kaldi
from . import common_utils
from .common_utils import load_params
from parameterized import parameterized
from .common_utils import (
TestBaseMixin,
load_params,
skipIfNoExec,
get_asset_path,
load_wav
)
def _convert_args(**kwargs):
args = []
......@@ -43,14 +48,12 @@ def _run_kaldi(command, input_type, input_value):
return torch.from_numpy(result.copy()) # copy supresses some torch warning
class Kaldi(common_utils.TestBaseMixin):
backend = 'sox'
class Kaldi(TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
@common_utils.skipIfNoExec('apply-cmvn-sliding')
@skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
......@@ -67,33 +70,33 @@ class Kaldi(common_utils.TestBaseMixin):
self.assert_equal(result, expected=kaldi_result)
@parameterized.expand(load_params('kaldi_test_fbank_args.json'))
@common_utils.skipIfNoExec('compute-fbank-feats')
@skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
wave_file = get_asset_path('kaldi_file.wav')
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params('kaldi_test_spectrogram_args.json'))
@common_utils.skipIfNoExec('compute-spectrogram-feats')
@skipIfNoExec('compute-spectrogram-feats')
def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
wave_file = get_asset_path('kaldi_file.wav')
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs)
command = ['compute-spectrogram-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params('kaldi_test_mfcc_args.json'))
@common_utils.skipIfNoExec('compute-mfcc-feats')
@skipIfNoExec('compute-mfcc-feats')
def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
wave_file = get_asset_path('kaldi_file.wav')
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)
command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file)
......
......@@ -160,7 +160,8 @@ class TestTransforms(common_utils.TorchaudioTestCase):
"""Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset('sinewave.wav')
path = common_utils.get_asset_path('sinewave.wav')
sound, sample_rate = common_utils.load_wav(path)
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram
......@@ -300,9 +301,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
hop_length = n_fft // 4
# Prepare mel spectrogram input. We use torchaudio to compute one.
common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset(
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
path = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
sound, sample_rate = common_utils.load_wav(path)
sound = sound[:, 2**10:2**10 + 2**14]
sound = sound.mean(dim=0, keepdim=True)
spec_orig = F.spectrogram(
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
......
......@@ -45,7 +45,7 @@ class Tester(common_utils.TorchaudioTestCase):
def test_AmplitudeToDB(self):
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, sample_rate = torchaudio.load(filepath)
waveform = common_utils.load_wav(filepath)[0]
mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
power_to_db_transform = transforms.AmplitudeToDB('power', 80.)
......@@ -115,7 +115,7 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
# check on multi-channel audio
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
x_stereo, sr_stereo = torchaudio.load(filepath) # (2, 278756), 44100
x_stereo = common_utils.load_wav(filepath)[0] # (2, 278756), 44100
spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394)
self.assertTrue(spectrogram_stereo.dim() == 3)
self.assertTrue(spectrogram_stereo.size(0) == 2)
......@@ -166,7 +166,7 @@ class Tester(common_utils.TorchaudioTestCase):
def test_resample_size(self):
input_path = common_utils.get_asset_path('sinewave.wav')
waveform, sample_rate = torchaudio.load(input_path)
waveform, sample_rate = common_utils.load_wav(input_path)
upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2
......
......@@ -2,7 +2,6 @@
import unittest
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
......@@ -616,6 +615,5 @@ class Transforms(common_utils.TestBaseMixin):
def test_Vad(self):
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
common_utils.set_audio_backend('default')
waveform, sample_rate = torchaudio.load(filepath)
waveform, sample_rate = common_utils.load_wav(filepath)
self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment