Unverified Commit 2416e5d0 authored by moto's avatar moto Committed by GitHub
Browse files

Clean up common_utils (#690)

parent ddb8577d
...@@ -7,9 +7,11 @@ import torch ...@@ -7,9 +7,11 @@ import torch
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import TestCase
import torchaudio import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA from torchaudio.common_utils import _check_module_exists
if IMPORT_LIBROSA: LIBROSA_AVAILABLE = _check_module_exists('librosa')
if LIBROSA_AVAILABLE:
import numpy as np import numpy as np
import librosa import librosa
import scipy import scipy
...@@ -19,7 +21,7 @@ import pytest ...@@ -19,7 +21,7 @@ import pytest
from . import common_utils from . import common_utils
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available") @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestFunctional(TestCase): class TestFunctional(TestCase):
"""Test suite for functions in `functional` module.""" """Test suite for functions in `functional` module."""
def test_griffinlim(self): def test_griffinlim(self):
...@@ -115,12 +117,8 @@ class TestFunctional(TestCase): ...@@ -115,12 +117,8 @@ class TestFunctional(TestCase):
]) ])
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3]) @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('hop_length', [256]) @pytest.mark.parametrize('hop_length', [256])
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
def test_phase_vocoder(complex_specgrams, rate, hop_length): def test_phase_vocoder(complex_specgrams, rate, hop_length):
# Using a decorator here causes parametrize to fail on Python 2
if not IMPORT_LIBROSA:
raise unittest.SkipTest('Librosa is not available')
# Due to cummulative sum, numerical error in using torch.float32 will # Due to cummulative sum, numerical error in using torch.float32 will
# result in bottom right values of the stretched sectrogram to not # result in bottom right values of the stretched sectrogram to not
# match with librosa. # match with librosa.
...@@ -158,7 +156,7 @@ def _load_audio_asset(*asset_paths, **kwargs): ...@@ -158,7 +156,7 @@ def _load_audio_asset(*asset_paths, **kwargs):
return sound, sample_rate return sound, sample_rate
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available") @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestTransforms(TestCase): class TestTransforms(TestCase):
"""Test suite for functions in `transforms` module.""" """Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
......
import importlib.util import importlib.util
def _check_module_exists(name: str) -> bool: def _check_module_exists(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without** r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a importing it. This is generally safer than try-catch block around a
`import X`. It avoids third party libraries breaking assumptions of some of `import X`. It avoids third party libraries breaking assumptions of some of
our tests, e.g., setting multiprocessing start method when imported our tests, e.g., setting multiprocessing start method when imported
(see librosa/#747, torchvision/#544). (see librosa/#747, torchvision/#544).
""" """
spec = importlib.util.find_spec(name) return all(importlib.util.find_spec(m) is not None for m in modules)
return spec is not None
IMPORT_NUMPY = _check_module_exists('numpy')
IMPORT_KALDI_IO = _check_module_exists('kaldi_io')
IMPORT_SCIPY = _check_module_exists('scipy')
IMPORT_LIBROSA = _check_module_exists('librosa')
# To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python) # To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python)
# needs to be installed. This is a light wrapper around kaldi_io that returns # needs to be installed. This is a light wrapper around kaldi_io that returns
# torch.Tensors. # torch.Tensors.
from typing import Any, Callable, Iterable, Tuple, Union from typing import Any, Callable, Iterable, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from torchaudio.common_utils import IMPORT_KALDI_IO, IMPORT_NUMPY from torchaudio.common_utils import _check_module_exists
if IMPORT_NUMPY: _KALDI_IO_AVAILABLE = _check_module_exists('kaldi_io', 'numpy')
import numpy as np
if IMPORT_KALDI_IO: if _KALDI_IO_AVAILABLE:
import numpy as np
import kaldi_io import kaldi_io
...@@ -38,7 +38,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any, ...@@ -38,7 +38,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any,
Returns: Returns:
Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat
""" """
if not IMPORT_KALDI_IO: if not _KALDI_IO_AVAILABLE:
raise ImportError('Could not import kaldi_io. Did you install it?') raise ImportError('Could not import kaldi_io. Did you install it?')
for key, np_arr in fn(file_or_fd): for key, np_arr in fn(file_or_fd):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment