"vscode:/vscode.git/clone" did not exist on "9407b45a0ca7cb3f2d695bc770176bc8b75dc96b"
test_models.py 3.25 KB
Newer Older
Tomás Osório's avatar
Tomás Osório committed
1
import torch
jimchen90's avatar
jimchen90 committed
2
from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork, _WaveRNN
Tomás Osório's avatar
Tomás Osório committed
3

4
from . import common_utils
Tomás Osório's avatar
Tomás Osório committed
5

6
7

class TestWav2Letter(common_utils.TorchaudioTestCase):
    """Output-shape checks for ``Wav2Letter``."""

    def test_waveform(self):
        """Validate the output dimensions of Wav2Letter on waveform-shaped input.

        Uses the model's default ``input_type`` with a single input channel.
        """
        batch_size = 2
        num_features = 1
        num_classes = 40
        input_length = 320

        model = Wav2Letter(num_classes=num_classes, num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        # 320 input samples are expected to collapse to 2 output frames.
        assert out.size() == (batch_size, num_classes, 2)

    def test_mfcc(self):
        """Validate the output dimensions of Wav2Letter with ``input_type="mfcc"``."""
        batch_size = 2
        num_features = 13
        num_classes = 40
        input_length = 2

        model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        # Expected output length is a fixed 2 frames for this input; it is not
        # asserted to equal input_length in general.
        assert out.size() == (batch_size, num_classes, 2)
jimchen90's avatar
jimchen90 committed
34
35


36
class TestMelResNet(common_utils.TorchaudioTestCase):
    """Output-shape checks for ``_MelResNet``."""

    def test_waveform(self):
        """Validate the output dimensions of a _MelResNet block.
        """

        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 128
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5

        model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)

        x = torch.rand(n_batch, n_freq, n_time)
        out = model(x)

        # The time axis is expected to shrink by (kernel_size - 1) frames.
        assert out.size() == (n_batch, n_output, n_time - kernel_size + 1)
jimchen90's avatar
jimchen90 committed
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76


class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
    """Output-shape checks for ``_UpsampleNetwork``."""

    def test_waveform(self):
        """Validate the output dimensions of a _UpsampleNetwork block.
        """

        upsample_scales = [5, 5, 8]
        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 256
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5

        # Overall upsampling factor: the product of the per-stage scales.
        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale

        model = _UpsampleNetwork(upsample_scales,
                                 n_res_block,
                                 n_freq,
                                 n_hidden,
                                 n_output,
                                 kernel_size)

        x = torch.rand(n_batch, n_freq, n_time)
        out1, out2 = model(x)

        # Both outputs share the same upsampled time length.
        expected_length = total_scale * (n_time - kernel_size + 1)
        assert out1.size() == (n_batch, n_freq, expected_length)
        assert out2.size() == (n_batch, n_output, expected_length)
jimchen90's avatar
jimchen90 committed
89
90
91
92
93


class TestWaveRNN(common_utils.TorchaudioTestCase):
    """Output-shape checks for ``_WaveRNN``."""

    def test_waveform(self):
        """Validate the output dimensions of a _WaveRNN model.
        """

        upsample_scales = [5, 5, 8]
        n_rnn = 512
        n_fc = 512
        n_classes = 512
        # Equal to the product of upsample_scales (5 * 5 * 8 = 200).
        hop_length = 200
        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 256
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5

        model = _WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)

        # Waveform length is chosen to match the expected output length below.
        n_samples = hop_length * (n_time - kernel_size + 1)
        x = torch.rand(n_batch, 1, n_samples)
        mels = torch.rand(n_batch, 1, n_freq, n_time)
        out = model(x, mels)

        assert out.size() == (n_batch, 1, n_samples, n_classes)