from typing import Optional, Tuple, List

import torch
from torch import Tensor
from torch.nn import Module

from . import components


class Wav2Vec2Model(Module):
    """Encoder model used in *wav2vec 2.0* [:footcite:`baevski2020wav2vec`].

    Note:
        To build the model, please use one of the factory functions.

    Args:
        feature_extractor (torch.nn.Module):
            Feature extractor that extracts feature vectors from raw audio Tensor.

        encoder (torch.nn.Module):
            Encoder that converts the audio features into the sequence of probability
            distribution (in negative log-likelihood) over labels.

        aux (torch.nn.Module or None, optional):
            Auxiliary module. If provided, the output from encoder is passed to this module.
    """
    def __init__(
            self,
            feature_extractor: Module,
            encoder: Module,
            aux: Optional[Module] = None,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.encoder = encoder
        # Optional readout head (e.g. a Linear over labels); applied in forward() only.
        self.aux = aux

    @torch.jit.export
    def extract_features(
            self,
            waveforms: Tensor,
            lengths: Optional[Tensor] = None,
            num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        """Extract feature vectors from raw waveforms

        This returns the list of outputs from the intermediate layers of
        transformer block in encoder.

        Args:
            waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio sample in the batch.
                Shape: ``(batch, )``.
            num_layers (int or None, optional):
                If given, limit the number of intermediate layers to go through.
                Providing `1` will stop the computation after going through one
                intermediate layer. If not given, the outputs from all the
                intermediate layers are returned.

        Returns:
            List of Tensors and an optional Tensor:
            List of Tensors
                Features from requested layers.
                Each Tensor is of shape: ``(batch, frames, feature dimension)``
            Tensor or None
                If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
                is returned. It indicates the valid length of each feature in the batch.
        """
        # feature_extractor also down-samples, so it returns updated lengths.
        x, lengths = self.feature_extractor(waveforms, lengths)
        x = self.encoder.extract_features(x, lengths, num_layers)
        return x, lengths

    def forward(
            self,
            waveforms: Tensor,
            lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Compute the sequence of probability distribution over labels.

        Args:
            waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio sample in the batch.
                Shape: ``(batch, )``.

        Returns:
            Tensor and an optional Tensor:
            Tensor
                The sequences of probability distribution (in logit) over labels.
                Shape: ``(batch, frames, num labels)``.
            Tensor or None
                If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
                is returned. It indicates the valid length of each feature in the batch.
        """
        x, lengths = self.feature_extractor(waveforms, lengths)
        x = self.encoder(x, lengths)
        if self.aux is not None:
            x = self.aux(x)
        return x, lengths


def _get_model(
        extractor_mode: str,
        extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
        extractor_conv_bias: bool,
        encoder_embed_dim: int,
        encoder_projection_dropout: float,
        encoder_pos_conv_kernel: int,
        encoder_pos_conv_groups: int,
        encoder_num_layers: int,
        encoder_num_heads: int,
        encoder_attention_dropout: float,
        encoder_ff_interm_features: int,
        encoder_ff_interm_dropout: float,
        encoder_dropout: float,
        encoder_layer_norm_first: bool,
        encoder_layer_drop: float,
        aux_num_out: Optional[int],
) -> Wav2Vec2Model:
    """Assemble a :class:`Wav2Vec2Model` from the given component configuration.

    Args:
        extractor_mode (str): Normalization mode of the conv feature extractor.
        extractor_conv_layer_config (list of tuple or None):
            ``(out_channels, kernel_size, stride)`` per conv layer.
            ``None`` selects the default configuration from the paper.
        extractor_conv_bias (bool): Whether conv layers have bias.
        encoder_* : Transformer encoder hyper-parameters, forwarded to
            :func:`components._get_encoder`.
        aux_num_out (int or None):
            If given, attach a ``Linear`` readout with this many output labels.

    Returns:
        Wav2Vec2Model: The assembled model.
    """
    conv_config = extractor_conv_layer_config
    if conv_config is None:
        # Default from the wav2vec 2.0 paper:
        # (512, 10, 5), then 4x (512, 3, 2), then 2x (512, 2, 2).
        conv_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    feature_extractor = components._get_feature_extractor(
        extractor_mode, conv_config, extractor_conv_bias)
    # The encoder input width is the channel count of the last conv layer.
    encoder = components._get_encoder(
        in_features=conv_config[-1][0],
        embed_dim=encoder_embed_dim,
        dropout_input=encoder_projection_dropout,
        pos_conv_kernel=encoder_pos_conv_kernel,
        pos_conv_groups=encoder_pos_conv_groups,
        num_layers=encoder_num_layers,
        num_heads=encoder_num_heads,
        attention_dropout=encoder_attention_dropout,
        ff_interm_features=encoder_ff_interm_features,
        ff_interm_dropout=encoder_ff_interm_dropout,
        dropout=encoder_dropout,
        layer_norm_first=encoder_layer_norm_first,
        layer_drop=encoder_layer_drop,
    )
    if aux_num_out is None:
        aux = None
    else:
        aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
    return Wav2Vec2Model(feature_extractor, encoder, aux)


def wav2vec2_base() -> Wav2Vec2Model:
    """Build wav2vec2 model with "base" configuration

    This is one of the model architectures used in *wav2vec 2.0*
    [:footcite:`baevski2020wav2vec`] for pretraining.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    # "base": 12 Transformer layers, 768-dim embedding, group-norm extractor.
    config = {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": None,
        "extractor_conv_bias": False,
        "encoder_embed_dim": 768,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 12,
        "encoder_num_heads": 12,
        "encoder_attention_dropout": 0.1,
        "encoder_ff_interm_features": 3072,
        "encoder_ff_interm_dropout": 0.1,
        "encoder_dropout": 0.1,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.1,
        "aux_num_out": None,
    }
    return _get_model(**config)


def wav2vec2_ft_base(num_out: int) -> Wav2Vec2Model:
    """Build "base" wav2vec2 with an extra linear module

    This is one of the model architectures used in *wav2vec 2.0*
    [:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.

    Args:
        num_out (int): The number of output labels.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    # Same as the "base" pretraining configuration, plus a linear readout
    # projecting encoder features onto ``num_out`` labels.
    config = {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": None,
        "extractor_conv_bias": False,
        "encoder_embed_dim": 768,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 12,
        "encoder_num_heads": 12,
        "encoder_attention_dropout": 0.1,
        "encoder_ff_interm_features": 3072,
        "encoder_ff_interm_dropout": 0.1,
        "encoder_dropout": 0.1,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.1,
        "aux_num_out": num_out,
    }
    return _get_model(**config)


def wav2vec2_large() -> Wav2Vec2Model:
    """Build wav2vec2 model with "large" configuration

    This is one of the model architectures used in *wav2vec 2.0*
    [:footcite:`baevski2020wav2vec`] for pretraining.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    # "large": 24 Transformer layers, 1024-dim embedding, group-norm extractor.
    config = {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": None,
        "extractor_conv_bias": False,
        "encoder_embed_dim": 1024,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 24,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.1,
        "encoder_ff_interm_features": 4096,
        "encoder_ff_interm_dropout": 0.1,
        "encoder_dropout": 0.1,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.1,
        "aux_num_out": None,
    }
    return _get_model(**config)


def wav2vec2_ft_large(num_out: int) -> Wav2Vec2Model:
    """Build "large" wav2vec2.0 model with an extra linear module

    This is one of the model architectures used in *wav2vec 2.0*
    [:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.

    Args:
        num_out (int): The number of output labels.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    # Same as the "large" pretraining configuration, plus a linear readout
    # projecting encoder features onto ``num_out`` labels.
    config = {
        "extractor_mode": "group_norm",
        "extractor_conv_layer_config": None,
        "extractor_conv_bias": False,
        "encoder_embed_dim": 1024,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 24,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.1,
        "encoder_ff_interm_features": 4096,
        "encoder_ff_interm_dropout": 0.1,
        "encoder_dropout": 0.1,
        "encoder_layer_norm_first": False,
        "encoder_layer_drop": 0.1,
        "aux_num_out": num_out,
    }
    return _get_model(**config)


def wav2vec2_large_lv60k() -> Wav2Vec2Model:
    """Build wav2vec2.0 model with "Large LV-60k" configuration

    This is one of the model architectures used in *wav2vec 2.0*
    [:footcite:`baevski2020wav2vec`] for pretraining.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    # LV-60k variant: layer-norm extractor with conv bias, pre-norm encoder
    # (layer_norm_first), and attention/encoder dropout disabled.
    config = {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": None,
        "extractor_conv_bias": True,
        "encoder_embed_dim": 1024,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 24,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 4096,
        "encoder_ff_interm_dropout": 0.1,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.1,
        "aux_num_out": None,
    }
    return _get_model(**config)


def wav2vec2_ft_large_lv60k(num_out: int) -> Wav2Vec2Model:
    """Build "Large LV-60k" wav2vec2.0 with an extra linear module

    This is one of the model architectures used in *wav2vec 2.0*
    [:footcite:`baevski2020wav2vec`] for fine-tuning for ASR task.

    Args:
        num_out (int): The number of output labels.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    # Same as the "Large LV-60k" pretraining configuration, plus a linear
    # readout projecting encoder features onto ``num_out`` labels.
    config = {
        "extractor_mode": "layer_norm",
        "extractor_conv_layer_config": None,
        "extractor_conv_bias": True,
        "encoder_embed_dim": 1024,
        "encoder_projection_dropout": 0.1,
        "encoder_pos_conv_kernel": 128,
        "encoder_pos_conv_groups": 16,
        "encoder_num_layers": 24,
        "encoder_num_heads": 16,
        "encoder_attention_dropout": 0.0,
        "encoder_ff_interm_features": 4096,
        "encoder_ff_interm_dropout": 0.1,
        "encoder_dropout": 0.0,
        "encoder_layer_norm_first": True,
        "encoder_layer_drop": 0.1,
        "aux_num_out": num_out,
    }
    return _get_model(**config)


def hubert_base() -> Wav2Vec2Model:
    """Build HuBERT model with "Base" configuration

    This is one of the model architectures used in *HuBERT*
    [:footcite:`hsu2021hubert`] for pretraining.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    return _get_model(
        extractor_mode='group_norm',
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=0.0,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,
        encoder_layer_drop=0.05,
        aux_num_out=None,
    )


def hubert_large() -> Wav2Vec2Model:
    """Build HuBERT model with "Large" configuration

    This is one of the model architectures used in *HuBERT*
    [:footcite:`hsu2021hubert`] for pretraining.

    Returns:
        HuBERT: The resulting model.
    """
    return _get_model(
        extractor_mode='layer_norm',
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.0,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.0,
        aux_num_out=None,
    )


391
def hubert_ft_large(num_out: int) -> Wav2Vec2Model:
    """Build "Large" HuBERT model with an extra linear module

    This is one of the model architectures used in *HuBERT*
    [:footcite:`hsu2021hubert`] for fine-tuning for ASR task.

    Args:
        num_out (int): The number of output labels.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    return _get_model(
        extractor_mode='layer_norm',
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=num_out,
    )


def hubert_xlarge() -> Wav2Vec2Model:
    """Build HuBERT model with "extra large" configuration

    This is one of the model architectures used in *HuBERT*
    [:footcite:`hsu2021hubert`] for pretraining.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    return _get_model(
        extractor_mode='layer_norm',
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1280,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=48,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=5120,
        encoder_ff_interm_dropout=0.0,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.0,
        aux_num_out=None,
    )


def hubert_ft_xlarge(num_out: int) -> Wav2Vec2Model:
    """Build "extra large" HuBERT model with an extra linear module

    This is one of the model architectures used in *HuBERT*
    [:footcite:`hsu2021hubert`] for fine-tuning for ASR task.

    Args:
        num_out (int): The number of output labels.

    Returns:
        Wav2Vec2Model: The resulting model.
    """
    return _get_model(
        extractor_mode='layer_norm',
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1280,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=48,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=5120,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=num_out,
    )