test_models.py 1.39 KB
Newer Older
Tomás Osório's avatar
Tomás Osório committed
1
import torch
jimchen90's avatar
jimchen90 committed
2
from torchaudio.models import Wav2Letter, _MelResNet
Tomás Osório's avatar
Tomás Osório committed
3

4
from . import common_utils
Tomás Osório's avatar
Tomás Osório committed
5

6
7

class TestWav2Letter(common_utils.TorchaudioTestCase):
    """Output-shape checks for the forward pass of ``Wav2Letter``."""

    def test_waveform(self):
        """Check the output size when the model is built with its default input type."""
        batch_size = 2
        num_features = 1
        num_classes = 40
        input_length = 320

        model = Wav2Letter(num_classes=num_classes, num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        # The trailing dimension (2) is hard-coded for this input_length.
        # NOTE(review): presumably it follows from the model's internal
        # strides/padding -- re-derive it if input_length is changed.
        assert out.size() == (batch_size, num_classes, 2)

    def test_mfcc(self):
        """Check the output size when the model is built with ``input_type="mfcc"``."""
        batch_size = 2
        num_features = 13
        num_classes = 40
        input_length = 2

        model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        # The trailing dimension (2) is hard-coded for this input_length.
        # NOTE(review): presumably tied to the model's layer configuration --
        # confirm against the model before changing input_length.
        assert out.size() == (batch_size, num_classes, 2)
jimchen90's avatar
jimchen90 committed
34
35


36
class TestMelResNet(common_utils.TorchaudioTestCase):
    """Output-shape checks for the forward pass of ``_MelResNet``."""

    def test_waveform(self):
        """Validate the output dimensions of a _MelResNet block.
        """

        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 128
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5

        # Arguments are passed positionally; keep this order in sync with the
        # _MelResNet constructor signature.
        model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)

        x = torch.rand(n_batch, n_freq, n_time)
        out = model(x)

        # The expected time dimension shrinks by (kernel_size - 1).
        # NOTE(review): presumably because the model applies an unpadded
        # convolution of that kernel size -- confirm against _MelResNet.
        assert out.size() == (n_batch, n_output, n_time - kernel_size + 1)