# test_case_utils.py
import shutil
import os.path
import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import is_module_available

from .backend_utils import set_audio_backend


class TempDirMixin:
    """Mixin to provide easy access to temp dir.

    Meant to be combined with a ``unittest.TestCase``-style class:
    ``get_temp_path`` calls ``self.id()`` and ``tearDownClass`` chains to
    ``super().tearDownClass()``.
    """
    # Lazily created ``tempfile.TemporaryDirectory``; ``None`` until first use
    # and reset to ``None`` again by ``tearDownClass``.
    temp_dir_ = None

    @classmethod
    def get_base_temp_dir(cls):
        """Return the base directory under which temp paths are created.

        Returns:
            str: The value of the ``TORCHAUDIO_TEST_TEMP_DIR`` environment
            variable when it is set; otherwise the name of a
            ``tempfile.TemporaryDirectory`` created on first call and
            cached on ``cls``.
        """
        # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
        # this is handy for debugging.
        key = 'TORCHAUDIO_TEST_TEMP_DIR'
        if key in os.environ:
            return os.environ[key]
        if cls.temp_dir_ is None:
            cls.temp_dir_ = tempfile.TemporaryDirectory()
        return cls.temp_dir_.name

    @classmethod
    def tearDownClass(cls):
        """Run the parent teardown, then remove the cached temp dir if any."""
        super().tearDownClass()
        if cls.temp_dir_ is not None:
            cls.temp_dir_.cleanup()
            # Reset so a later call to ``get_base_temp_dir`` creates a fresh one.
            cls.temp_dir_ = None

    def get_temp_path(self, *paths):
        """Build a path inside a per-test directory and create its parent.

        Args:
            *paths: Path components joined below ``<base temp dir>/<self.id()>``.

        Returns:
            str: The joined path. Only the parent directory is created;
            the path itself is not.
        """
        temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
        path = os.path.join(temp_dir, *paths)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        return path


class TestBaseMixin:
    """Mixin giving test cases a uniform place to declare device, dtype and backend.

    The three class attributes default to ``None``; ``setUp`` passes
    ``backend`` to ``set_audio_backend`` after the parent ``setUp`` has run.
    """
    backend = device = dtype = None

    def setUp(self):
        """Run the parent ``setUp``, then select this case's audio backend."""
        super().setUp()
        selected_backend = self.backend
        set_audio_backend(selected_backend)


class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
    """Test case base combining ``TestBaseMixin`` with ``PytorchTestCase``."""


def skipIfNoExec(cmd):
    """Return a skip decorator that is active when ``cmd`` is not found by ``shutil.which``."""
    executable_missing = shutil.which(cmd) is None
    reason = f'`{cmd}` is not available'
    return unittest.skipIf(executable_missing, reason)


def skipIfNoModule(module, display_name=None):
    """Return a skip decorator that is active when ``module`` is not available.

    Args:
        module: Module name passed to ``is_module_available``.
        display_name: Name shown in the skip reason; falls back to
            ``module`` when omitted or falsy.
    """
    label = display_name if display_name else module
    module_missing = not is_module_available(module)
    return unittest.skipIf(module_missing, f'"{label}" is not available')


# Ready-made skip decorators, evaluated once when this module is imported.

# Skips when 'sox' is not among the backends reported by torchaudio.
skipIfNoSoxBackend = unittest.skipIf(
    condition='sox' not in torchaudio.list_audio_backends(),
    reason='Sox backend not available',
)
# Skips when torch reports that CUDA is unavailable.
skipIfNoCuda = unittest.skipIf(
    not torch.cuda.is_available(),
    'CUDA not available',
)
# Skips when the `torchaudio._torchaudio` module is not available.
skipIfNoExtension = skipIfNoModule(
    'torchaudio._torchaudio',
    display_name='torchaudio C++ extension',
)