test_torchscript.py 5.5 KB
Newer Older
1
import itertools
2
from typing import Optional
3
4

import torch
moto's avatar
moto committed
5
import torchaudio
6
7
8
9
10
11
12
from parameterized import parameterized

from ..common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    skipIfNoExec,
    skipIfNoExtension,
moto's avatar
moto committed
13
    get_wav_data,
14
15
    save_wav,
    load_wav,
moto's avatar
moto committed
16
17
18
19
    sox_utils,
)
from .common import (
    name_func,
20
21
22
)


23
def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
    """Return the metadata that `torchaudio.info` reports for ``filepath``.

    This is a thin, fully annotated module-level wrapper; the tests in this
    file pass it to `torch.jit.script` and compare the scripted result with
    the eager one.
    """
    return torchaudio.info(filepath)
25
26


moto's avatar
moto committed
27
def py_load_func(filepath: str, normalize: bool, channels_first: bool):
    """Load ``filepath`` through `torchaudio.load`.

    Both flags are forwarded by keyword and the value returned by
    `torchaudio.load` is handed back unchanged. The tests in this file pass
    this wrapper to `torch.jit.script`.
    """
    return torchaudio.load(
        filepath, normalize=normalize, channels_first=channels_first)


32
33
34
35
36
37
38
def py_save_func(
        filepath: str,
        tensor: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        compression: Optional[float] = None,
):
    """Write ``tensor`` to ``filepath`` through `torchaudio.save`.

    All arguments are forwarded positionally, in the order they are declared
    here. Nothing is returned. The tests in this file pass this wrapper to
    `torch.jit.script`.
    """
    torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)
40
41


42
43
44
@skipIfNoExec('sox')
@skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
    """TorchScript-ability Test suite for `sox_io_backend`"""
    backend = 'sox_io'

    def _script_round_trip(self, func, file_name):
        """Script ``func``, save it as ``file_name`` in the temp dir and load it back.

        Returns the module produced by `torch.jit.load`, so every test
        exercises the serialized form rather than the in-memory script.
        """
        script_path = self.get_temp_path(file_name)
        torch.jit.script(func).save(script_path)
        return torch.jit.load(script_path)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_info_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` is torchscript-able and returns the same result"""
        audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
        save_wav(audio_path, data, sample_rate)

        ts_info_func = self._script_round_trip(py_info_func, 'info_func.zip')

        py_info = py_info_func(audio_path)
        ts_info = ts_info_func(audio_path)

        # assertEqual (rather than a bare assert) reports both values on failure
        # and matches the style of the other tests in this class.
        self.assertEqual(py_info.sample_rate, ts_info.sample_rate)
        self.assertEqual(py_info.num_frames, ts_info.num_frames)
        self.assertEqual(py_info.num_channels, ts_info.num_channels)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
        [False, True],
        [False, True],
    )), name_func=name_func)
    def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """`sox_io_backend.load` is torchscript-able and returns the same result"""
        audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
        data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
        save_wav(audio_path, data, sample_rate)

        ts_load_func = self._script_round_trip(py_load_func, 'load_func.zip')

        py_data, py_sr = py_load_func(
            audio_path, normalize=normalize, channels_first=channels_first)
        ts_data, ts_sr = ts_load_func(
            audio_path, normalize=normalize, channels_first=channels_first)

        self.assertEqual(py_sr, ts_sr)
        self.assertEqual(py_data, ts_data)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_save_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.save` is torchscript-able and both versions write the expected wav data"""
        ts_save_func = self._script_round_trip(py_save_func, 'save_func.zip')

        expected = get_wav_data(dtype, num_channels)
        py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
        ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')

        py_save_func(py_path, expected, sample_rate, True, None)
        ts_save_func(ts_path, expected, sample_rate, True, None)

        py_data, py_sr = load_wav(py_path)
        ts_data, ts_sr = load_wav(ts_path)

        self.assertEqual(sample_rate, py_sr)
        self.assertEqual(sample_rate, ts_sr)
        self.assertEqual(expected, py_data)
        self.assertEqual(expected, ts_data)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_save_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.save` is torchscript-able and both versions write the expected flac data"""
        ts_save_func = self._script_round_trip(py_save_func, 'save_func.zip')

        expected = get_wav_data('float32', num_channels)
        py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
        ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')

        py_save_func(py_path, expected, sample_rate, True, compression_level)
        ts_save_func(ts_path, expected, sample_rate, True, compression_level)

        # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
        py_path_wav = f'{py_path}.wav'
        ts_path_wav = f'{ts_path}.wav'
        sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32)
        sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32)

        py_data, py_sr = load_wav(py_path_wav, normalize=True)
        ts_data, ts_sr = load_wav(ts_path_wav, normalize=True)

        self.assertEqual(sample_rate, py_sr)
        self.assertEqual(sample_rate, ts_sr)
        self.assertEqual(expected, py_data)
        self.assertEqual(expected, ts_data)