"""Shape and TorchScript smoke tests for torchaudio model implementations."""
import itertools
from collections import namedtuple

import torch
from parameterized import parameterized

from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import torch_script


class TestWav2Letter(common_utils.TorchaudioTestCase):
    """Shape smoke tests for Wav2Letter on waveform and MFCC inputs."""

    def test_waveform(self):
        """Raw-waveform input yields a (batch, classes, time) tensor."""
        n_batch = 2
        n_feature = 1
        n_class = 40
        n_input_length = 320

        model = Wav2Letter(num_classes=n_class, num_features=n_feature)

        inputs = torch.rand(n_batch, n_feature, n_input_length)
        outputs = model(inputs)

        # The strided front-end reduces 320 input samples to 2 time steps.
        assert outputs.size() == (n_batch, n_class, 2)

    def test_mfcc(self):
        """MFCC input ("mfcc" input_type) yields a (batch, classes, time) tensor."""
        n_batch = 2
        n_feature = 13
        n_class = 40
        n_input_length = 2

        model = Wav2Letter(num_classes=n_class, input_type="mfcc", num_features=n_feature)

        inputs = torch.rand(n_batch, n_feature, n_input_length)
        outputs = model(inputs)

        assert outputs.size() == (n_batch, n_class, 2)


class TestMelResNet(common_utils.TorchaudioTestCase):
    """Shape smoke test for the MelResNet building block."""

    def test_waveform(self):
        """Validate the output dimensions of a MelResNet block."""
        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 128
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5

        block = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)

        spectrogram = torch.rand(n_batch, n_freq, n_time)
        result = block(spectrogram)

        # The unpadded entry convolution trims (kernel_size - 1) time steps.
        assert result.size() == (n_batch, n_output, n_time - kernel_size + 1)


class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
    """Shape smoke test for the UpsampleNetwork building block."""

    def test_waveform(self):
        """Validate the output dimensions of a UpsampleNetwork block."""
        upsample_scales = [5, 5, 8]
        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 256
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5

        # Overall upsampling factor is the product of the per-stage scales.
        total_scale = 1
        for scale in upsample_scales:
            total_scale = total_scale * scale

        network = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)

        spectrogram = torch.rand(n_batch, n_freq, n_time)
        upsampled, aux = network(spectrogram)

        # Both outputs share the upsampled time axis; the conv trims the input first.
        expected_time = total_scale * (n_time - kernel_size + 1)
        assert upsampled.size() == (n_batch, n_freq, expected_time)
        assert aux.size() == (n_batch, n_output, expected_time)


class TestWaveRNN(common_utils.TorchaudioTestCase):
    """Forward-pass, inference, and TorchScript smoke tests for WaveRNN."""

    def test_waveform(self):
        """Validate the output dimensions of a WaveRNN model."""
        upsample_scales = [5, 5, 8]
        n_rnn = 512
        n_fc = 512
        n_classes = 512
        hop_length = 200
        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 256
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5

        model = WaveRNN(
            upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
        )

        # Waveform length must match the upsampled (conv-trimmed) mel length.
        waveform = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
        mels = torch.rand(n_batch, 1, n_freq, n_time)
        logits = model(waveform, mels)

        assert logits.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)

    def test_infer_waveform(self):
        """Validate the output dimensions of a WaveRNN model's infer method."""
        upsample_scales = [5, 5, 8]
        n_rnn = 128
        n_fc = 128
        n_classes = 128
        hop_length = 200
        n_batch = 2
        n_time = 50
        n_freq = 25
        n_output = 64
        n_res_block = 2
        n_hidden = 32
        kernel_size = 5

        model = WaveRNN(
            upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
        )

        specgram = torch.rand(n_batch, n_freq, n_time)
        # Second item is half-length to exercise per-example length handling.
        lengths = torch.tensor([n_time, n_time // 2])
        waveform, waveform_lengths = model.infer(specgram, lengths)

        assert waveform.size() == (n_batch, 1, hop_length * n_time)
        assert waveform_lengths[0] == hop_length * n_time
        assert waveform_lengths[1] == hop_length * n_time // 2

    def test_torchscript_infer(self):
        """Scripted model outputs the same as eager mode"""
        upsample_scales = [5, 5, 8]
        n_rnn = 128
        n_fc = 128
        n_classes = 128
        hop_length = 200
        n_batch = 2
        n_time = 50
        n_freq = 25
        n_output = 64
        n_res_block = 2
        n_hidden = 32
        kernel_size = 5

        model = WaveRNN(
            upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
        )
        model.eval()

        specgram = torch.rand(n_batch, n_freq, n_time)
        # Re-seed before each run so the sampling inside infer is identical.
        torch.random.manual_seed(0)
        eager_result = model.infer(specgram)
        torch.random.manual_seed(0)
        scripted_result = torch_script(model).infer(specgram)
        self.assertEqual(eager_result, scripted_result)


_ConvTasNetParams = namedtuple(
170
    "_ConvTasNetParams",
moto's avatar
moto committed
171
    [
172
173
174
175
176
177
178
179
        "enc_num_feats",
        "enc_kernel_size",
        "msk_num_feats",
        "msk_num_hidden_feats",
        "msk_kernel_size",
        "msk_num_layers",
        "msk_num_stacks",
    ],
moto's avatar
moto committed
180
181
182
183
)


class TestConvTasNet(common_utils.TorchaudioTestCase):
    """Forward-pass shape tests for ConvTasNet over the paper's configurations.

    Note: the original parameter list contained three configurations repeated
    verbatim; the duplicates are removed so each configuration runs once.
    """

    @parameterized.expand(
        list(
            itertools.product(
                [2, 3],  # num_sources
                [
                    _ConvTasNetParams(128, 40, 128, 256, 3, 7, 2),
                    _ConvTasNetParams(256, 40, 128, 256, 3, 7, 2),
                    _ConvTasNetParams(512, 40, 128, 256, 3, 7, 2),
                    _ConvTasNetParams(512, 40, 128, 512, 3, 7, 2),
                    _ConvTasNetParams(512, 40, 256, 256, 3, 7, 2),
                    _ConvTasNetParams(512, 40, 256, 512, 3, 7, 2),
                    _ConvTasNetParams(512, 40, 128, 512, 3, 6, 4),
                    _ConvTasNetParams(512, 40, 128, 512, 3, 4, 6),
                    _ConvTasNetParams(512, 40, 128, 512, 3, 8, 3),
                    _ConvTasNetParams(512, 32, 128, 512, 3, 8, 3),
                    _ConvTasNetParams(512, 16, 128, 512, 3, 8, 3),
                ],
            )
        )
    )
    def test_paper_configuration(self, num_sources, model_params):
        """ConvTasNet model works on the valid configurations in the paper"""
        batch_size = 32
        num_frames = 8000

        model = ConvTasNet(
            num_sources=num_sources,
            enc_kernel_size=model_params.enc_kernel_size,
            enc_num_feats=model_params.enc_num_feats,
            msk_kernel_size=model_params.msk_kernel_size,
            msk_num_feats=model_params.msk_num_feats,
            msk_num_hidden_feats=model_params.msk_num_hidden_feats,
            msk_num_layers=model_params.msk_num_layers,
            msk_num_stacks=model_params.msk_num_stacks,
        )
        tensor = torch.rand(batch_size, 1, num_frames)
        output = model(tensor)

        # One separated waveform per source, same length as the input mixture.
        assert output.shape == (batch_size, num_sources, num_frames)


class TestDeepSpeech(common_utils.TorchaudioTestCase):
    """Shape smoke test for the DeepSpeech model."""

    def test_deepspeech(self):
        """DeepSpeech maps (batch, channel, time, feature) to (batch, time, class)."""
        n_batch = 2
        n_feature = 1
        n_channel = 1
        n_class = 40
        n_time = 320

        model = DeepSpeech(n_feature=n_feature, n_class=n_class)

        inputs = torch.rand(n_batch, n_channel, n_time, n_feature)
        logits = model(inputs)

        assert logits.size() == (n_batch, n_time, n_class)