import torch
from torchaudio.models import Wav2Letter, _MelResNet

from . import common_utils


class TestWav2Letter(common_utils.TorchaudioTestCase):
    """Output-shape checks for ``Wav2Letter``.

    Each test builds a model, feeds it a random batch and compares the size
    of the result with the expected ``(batch, classes, frames)`` tuple. Only
    shapes are verified; output values are never inspected.
    """

    def test_waveform(self):
        """Model built without an explicit ``input_type``: (2, 1, 320) -> (2, 40, 2)."""
        batch_size = 2
        num_features = 1
        num_classes = 40
        input_length = 320

        model = Wav2Letter(num_classes=num_classes, num_features=num_features)

        # Random input is enough because only the output shape is asserted.
        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        assert out.size() == (batch_size, num_classes, 2)

    def test_mfcc(self):
        """Model built with ``input_type="mfcc"``: (2, 13, 2) -> (2, 40, 2)."""
        batch_size = 2
        num_features = 13
        num_classes = 40
        input_length = 2

        model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)

        # Random input is enough because only the output shape is asserted.
        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        assert out.size() == (batch_size, num_classes, 2)


class TestMelResNet(common_utils.TorchaudioTestCase):
    """Output-shape check for the private ``_MelResNet`` module."""

    def test_waveform(self):
        """Check that ``_MelResNet`` maps (2, 100, 200) to (2, 128, 200 - 2 * pad).

        The model is built positionally as
        ``_MelResNet(res_blocks, input_dims, hidden_dims, output_dims, pad)``.
        Only the output size is asserted, not the output values.
        """
        batch_size = 2
        num_features = 200
        input_dims = 100
        output_dims = 128
        res_blocks = 10
        hidden_dims = 128
        pad = 2

        model = _MelResNet(res_blocks, input_dims, hidden_dims, output_dims, pad)

        # Input layout is (batch, input_dims, num_features).
        x = torch.rand(batch_size, input_dims, num_features)
        out = model(x)

        # The last axis is expected to shrink by ``pad`` on each side.
        assert out.size() == (batch_size, output_dims, num_features - pad * 2)