test_models.py 1.01 KB
Newer Older
Tomás Osório's avatar
Tomás Osório committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import pytest

import torch
from torchaudio.models import Wav2Letter


class TestWav2Letter:
    """Output-shape tests for ``Wav2Letter`` on waveform and MFCC inputs.

    Each test constructs the model from its parametrized hyper-parameters,
    feeds a random ``(batch_size, num_features, input_length)`` tensor through
    it, and checks the size of the returned tensor.
    """

    @pytest.mark.parametrize('batch_size', [2])
    @pytest.mark.parametrize('num_features', [1])
    @pytest.mark.parametrize('num_classes', [40])
    @pytest.mark.parametrize('input_length', [320])
    def test_waveform(self, batch_size, num_features, num_classes, input_length):
        """Waveform input produces a ``(batch_size, num_classes, 2)`` output."""
        # Forward the parametrized values to the model so that editing the
        # parametrize lists above really changes the configuration under test
        # (previously the defaults were used and the parameters were ignored).
        model = Wav2Letter(num_classes=num_classes, input_type="waveform",
                           num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        # NOTE(review): the trailing time dimension of 2 is tied to
        # input_length=320 and (presumably) the model's convolution
        # kernel/stride/padding settings; update it together with input_length.
        assert out.size() == (batch_size, num_classes, 2)

    @pytest.mark.parametrize('batch_size', [2])
    @pytest.mark.parametrize('num_features', [13])
    @pytest.mark.parametrize('num_classes', [40])
    @pytest.mark.parametrize('input_length', [2])
    def test_mfcc(self, batch_size, num_features, num_classes, input_length):
        """MFCC input produces a ``(batch_size, num_classes, 2)`` output."""
        # Use the parametrized feature/class counts rather than a hard-coded
        # 13, so the model and the input tensor always agree on num_features.
        model = Wav2Letter(num_classes=num_classes, input_type="mfcc",
                           num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        # NOTE(review): the trailing time dimension of 2 is tied to
        # input_length=2 and (presumably) the model's convolution
        # kernel/stride/padding settings; update it together with input_length.
        assert out.size() == (batch_size, num_classes, 2)