test_models.py 1.25 KB
Newer Older
Tomás Osório's avatar
Tomás Osório committed
1
import torch
jimchen90's avatar
jimchen90 committed
2
from torchaudio.models import Wav2Letter, _MelResNet
Tomás Osório's avatar
Tomás Osório committed
3
4
5


class TestWav2Letter:
    """Output-shape checks for ``Wav2Letter`` on waveform and MFCC inputs."""

    def test_waveform(self):
        """A (2, 1, 320) raw-waveform batch yields a (2, 40, 2) output.

        Only the output shape is asserted; the random input values do not
        affect it.
        """
        n_batch, n_channels, n_samples = 2, 1, 320
        n_labels = 40

        wav2letter = Wav2Letter(num_classes=n_labels, num_features=n_channels)

        waveform = torch.rand(n_batch, n_channels, n_samples)
        logits = wav2letter(waveform)

        # NOTE(review): the expected time length of 2 presumably comes from
        # the model's strided front-end convolutions -- confirm against
        # the Wav2Letter layer definitions if the architecture changes.
        expected_shape = (n_batch, n_labels, 2)
        assert logits.size() == expected_shape

    def test_mfcc(self):
        """A (2, 13, 2) MFCC batch with ``input_type="mfcc"`` yields (2, 40, 2)."""
        n_batch, n_coeffs, n_frames = 2, 13, 2
        n_labels = 40

        wav2letter = Wav2Letter(
            num_classes=n_labels,
            input_type="mfcc",
            num_features=n_coeffs,
        )

        mfcc = torch.rand(n_batch, n_coeffs, n_frames)
        logits = wav2letter(mfcc)

        expected_shape = (n_batch, n_labels, 2)
        assert logits.size() == expected_shape
jimchen90's avatar
jimchen90 committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


class TestMelResNet:
    """Output-shape checks for ``_MelResNet``."""

    def test_waveform(self):
        """Channels map to ``output_dims`` and the time axis loses ``2 * pad`` frames.

        The input is a random (batch, input_dims, frames) tensor. Presumably
        it stands in for a spectrogram despite the method name; verify intent.
        """
        batch_size, num_frames = 2, 200
        input_dims, hidden_dims, output_dims = 100, 128, 128
        res_blocks, pad = 10, 2

        model = _MelResNet(res_blocks, input_dims, hidden_dims, output_dims, pad)

        specgram = torch.rand(batch_size, input_dims, num_frames)
        out = model(specgram)

        trimmed_frames = num_frames - 2 * pad
        assert out.size() == (batch_size, output_dims, trimmed_frames)