playback_test.py 2.17 KB
Newer Older
mayp777's avatar
UPDATE  
mayp777 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio.io import play_audio, StreamWriter
from torchaudio_unittest.common_utils import get_sinusoid, skipIfNoAudioDevice, skipIfNoMacOS, TorchaudioTestCase


@skipIfNoAudioDevice
@skipIfNoMacOS
class PlaybackInterfaceTest(TorchaudioTestCase):
    """Interface tests for ``play_audio``.

    ``StreamWriter.write_audio_chunk`` is patched in every test, so the tests
    observe whether a write was attempted instead of relying on real playback.
    NOTE(review): the two class decorators presumably skip the suite when no
    audio device is present or the platform is not macOS -- confirm in
    ``common_utils``.
    """

    @staticmethod
    def _sinusoid(dtype, sample_rate):
        """Return the shared test signal produced by ``get_sinusoid``.

        Args:
            dtype: torch dtype forwarded to ``get_sinusoid``.
            sample_rate: sample rate forwarded to ``get_sinusoid``.
        """
        return get_sinusoid(
            frequency=440,
            sample_rate=sample_rate,
            duration=1,  # seconds
            n_channels=1,
            dtype=dtype,
            device="cpu",
            channels_first=False,
        )

    @parameterized.expand([(name,) for name in ("uint8", "int16", "int32", "int64", "float32", "float64")])
    @patch.object(StreamWriter, "write_audio_chunk")
    def test_playaudio(self, dtype, writeaudio_mock):
        """Check that ``play_audio`` hands data to the stream writer.

        The mocked ``write_audio_chunk`` stands in for the output device, so
        the test only verifies that a write was requested.
        """
        rate = 8000
        signal = self._sinusoid(getattr(torch, dtype), rate)

        play_audio(signal, sample_rate=rate)

        writeaudio_mock.assert_called()

    @parameterized.expand(
        [
            # Invalid number of dimensions (!= 2)
            ("int16", 1, "audiotoolbox"),
            ("int16", 3, "audiotoolbox"),
            # Invalid tensor type
            ("complex64", 2, "audiotoolbox"),
            # Invalid output device
            ("int16", 2, "audiotool"),
        ]
    )
    @patch.object(StreamWriter, "write_audio_chunk")
    def test_playaudio_invalid_options(self, dtype, ndim, device, writeaudio_mock):
        """Check that ``play_audio`` raises ``ValueError`` for each invalid option set."""
        rate = 8000

        # Drop all size-1 axes, then append a trailing singleton axis
        # ``ndim - 1`` times to reach the dimensionality under test.
        signal = self._sinusoid(getattr(torch, dtype), rate).squeeze()
        for _ in range(1, ndim):
            signal = signal.unsqueeze(-1)

        self.assertRaises(ValueError, play_audio, signal, sample_rate=rate, device=device)