common.py 1.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import itertools
from unittest import skipIf

from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available


def name_func(func, _, params):
    """Build a test name from a function and its positional parameters.

    The result is ``func.__name__`` followed by an underscore and the
    ``str()`` of every item in ``params.args`` joined with underscores.
    The second positional argument is accepted but ignored.
    """
    suffix = "_".join(map(str, params.args))
    return f'{func.__name__}_{suffix}'


def dtype2subtype(dtype):
    """Return the subtype string that corresponds to a dtype name.

    Args:
        dtype: One of ``"float64"``, ``"float32"``, ``"int32"``,
            ``"int16"``, ``"uint8"`` or ``"int8"``.

    Raises:
        KeyError: If ``dtype`` is not one of the names listed above.
    """
    subtype_by_dtype = {
        "float64": "DOUBLE",
        "float32": "FLOAT",
        "int32": "PCM_32",
        "int16": "PCM_16",
        "uint8": "PCM_U8",
        "int8": "PCM_S8",
    }
    return subtype_by_dtype[dtype]


def skipIfFormatNotSupported(fmt):
    """Return a ``unittest.skipIf`` decorator guarding tests that need ``fmt``.

    The decorated test is skipped unconditionally when the ``soundfile``
    module is not available, and otherwise skipped when ``fmt`` is not
    among the formats reported by ``soundfile.available_formats()``.

    Args:
        fmt: Format name to look up in ``soundfile.available_formats()``.

    Returns:
        The decorator produced by ``unittest.skipIf``.
    """
    if not is_module_available("soundfile"):
        return skipIf(True, '"soundfile" not available.')

    # Imported lazily so that this module can still be imported when
    # soundfile is not installed.
    import soundfile

    supported = soundfile.available_formats()
    return skipIf(fmt not in supported, f'"{fmt}" is not supported by soundfile')


def parameterize(*params):
    """Expand a test over the Cartesian product of the given iterables.

    Every positional argument is an iterable of candidate values. The
    product of all of them is materialised as a list and handed to
    ``parameterized.expand``, with ``name_func`` used to name the cases.
    """
    combinations = list(itertools.product(*params))
    return parameterized.expand(combinations, name_func=name_func)
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57


def fetch_wav_subtype(dtype, encoding, bits_per_sample):
    """Resolve the wav subtype for a dtype / encoding / bit-depth combination.

    When both ``encoding`` and ``bits_per_sample`` are ``None`` the subtype
    is derived from ``dtype`` via ``dtype2subtype``. In every other case the
    subtype is determined by the ``(encoding, bits_per_sample)`` pair alone
    and ``dtype`` is not consulted.

    Args:
        dtype: Dtype name; only used when neither ``encoding`` nor
            ``bits_per_sample`` is given.
        encoding: ``None`` or one of ``'PCM_U'``, ``'PCM_S'``, ``'PCM_F'``,
            ``'ULAW'``, ``'ALAW'``.
        bits_per_sample: ``None`` or a bit depth valid for ``encoding``.

    Returns:
        The subtype string, e.g. ``"PCM_16"``.

    Raises:
        ValueError: If the ``(encoding, bits_per_sample)`` pair is not in
            the table below.
        KeyError: If the subtype has to come from ``dtype`` and
            ``dtype2subtype`` does not know that dtype.
    """
    if encoding is None and bits_per_sample is None:
        # Resolve dtype only here: looking it up unconditionally would raise
        # KeyError for an unlisted dtype even when encoding/bits_per_sample
        # already determine the subtype.
        return dtype2subtype(dtype)

    subtype = {
        (None, 8): "PCM_U8",
        ('PCM_U', None): "PCM_U8",
        ('PCM_U', 8): "PCM_U8",
        ('PCM_S', None): "PCM_32",
        ('PCM_S', 16): "PCM_16",
        ('PCM_S', 32): "PCM_32",
        ('PCM_F', None): "FLOAT",
        ('PCM_F', 32): "FLOAT",
        ('PCM_F', 64): "DOUBLE",
        ('ULAW', None): "ULAW",
        ('ULAW', 8): "ULAW",
        ('ALAW', None): "ALAW",
        ('ALAW', 8): "ALAW",
    }.get((encoding, bits_per_sample))
    if subtype:
        return subtype
    raise ValueError(
        f"wav does not support ({encoding}, {bits_per_sample}).")