# test_models.py — tests for torchaudio's Wav2Letter model.
# (Recovered from a GitLab blame view; history credits Tomás Osório and jimchen90.)
import torch
from torchaudio.models import Wav2Letter


class TestWav2Letter:
    """Output-shape checks for ``Wav2Letter`` on waveform and MFCC inputs."""

    def test_waveform(self):
        """A single-channel waveform batch yields a ``(batch, num_classes, 2)`` output."""
        batch_size = 2
        num_features = 1
        num_classes = 40
        input_length = 320

        # No ``input_type`` is passed, so this relies on the model's default;
        # the test name assumes that default is waveform input — confirm.
        model = Wav2Letter(num_classes=num_classes, num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        # The time dimension of 2 is the value this test pins; presumably it
        # follows from the model's conv strides/padding for 320 samples — confirm.
        assert out.size() == (batch_size, num_classes, 2)

    def test_mfcc(self):
        """A 13-coefficient MFCC batch yields a ``(batch, num_classes, 2)`` output."""
        batch_size = 2
        num_features = 13
        num_classes = 40
        input_length = 2

        model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        # Expected time dimension of 2 for a 2-frame input is pinned here;
        # presumably determined by the model's conv padding — confirm.
        assert out.size() == (batch_size, num_classes, 2)