from typing import Optional, Tuple, List

import torch
from torch import Tensor
from torch.nn import Module

from . import components


class Wav2Vec2Model(Module):
    """torchaudio.models.Wav2Vec2Model(feature_extractor: torch.nn.Module, encoder: torch.nn.Module, aux: Optional[torch.nn.Module] = None)

    Encoder model used in *wav2vec 2.0* [:footcite:`baevski2020wav2vec`].

    Note:
        To build the model, please use one of the factory functions.

    Args:
        feature_extractor (torch.nn.Module):
            Feature extractor that extracts feature vectors from raw audio Tensor.

        encoder (torch.nn.Module):
            Encoder that converts the audio features into the sequence of probability
            distributions (in negative log-likelihood) over labels.

        aux (torch.nn.Module or None, optional):
            Auxiliary module. If provided, the output from the encoder is passed to this module.
    """  # noqa: E501
    def __init__(
            self,
            feature_extractor: Module,
            encoder: Module,
            aux: Optional[Module] = None,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.encoder = encoder
        self.aux = aux

    @torch.jit.export
    def extract_features(
            self,
            waveforms: Tensor,
            lengths: Optional[Tensor] = None,
            num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        """Extract feature vectors from raw waveforms

        This returns the list of outputs from the intermediate layers of
        the transformer block in the encoder.

        Args:
            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `(batch, )`.
                When ``waveforms`` contains audios with different durations,
                providing the ``lengths`` argument lets the model compute
                the corresponding valid output lengths and apply a proper mask in
                the transformer attention layer.
                If ``None``, it is assumed that the entire audio waveform
                length is valid.
            num_layers (int or None, optional):
                If given, limit the number of intermediate layers to go through.
                Providing `1` will stop the computation after going through one
                intermediate layer. If not given, the outputs from all the
                intermediate layers are returned.

        Returns:
            (List[Tensor], Optional[Tensor]):
            List of Tensors
                Features from requested layers.
                Each Tensor is of shape: `(batch, time frame, feature dimension)`
            Tensor or None
                If the ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                is returned.
                It indicates the valid length in the time axis of each feature Tensor.
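
        Example:
            A minimal sketch of typical usage; ``wav2vec2_base`` is one of the
            factory functions defined later in this module, and the input
            shapes are illustrative.

            .. code-block:: python

                model = wav2vec2_base()
                waveforms = torch.randn(2, 16000)  # two 1-second clips at 16 kHz
                features, lengths = model.extract_features(waveforms)
                # features is a list with one Tensor per intermediate layer,
                # each of shape (batch, frames, 768)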
        """
        x, lengths = self.feature_extractor(waveforms, lengths)
        x = self.encoder.extract_features(x, lengths, num_layers)
        return x, lengths

    def forward(
            self,
            waveforms: Tensor,
            lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Compute the sequence of probability distribution over labels.

        Args:
            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `(batch, )`.
                When ``waveforms`` contains audios with different durations,
                providing the ``lengths`` argument lets the model compute
                the corresponding valid output lengths and apply a proper mask in
                the transformer attention layer.
                If ``None``, it is assumed that all the audios in ``waveforms``
                have valid length. Default: ``None``.

        Returns:
            (Tensor, Optional[Tensor]):
            Tensor
                The sequence of probability distributions (in logits) over labels.
                Shape: `(batch, frames, num labels)`.
            Tensor or None
                If the ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                is returned.
                It indicates the valid length in the time axis of the output Tensor.
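
        Example:
            A minimal sketch of typical usage; ``wav2vec2_base`` is one of the
            factory functions defined later in this module, and the
            ``aux_num_out`` value and input shapes are illustrative.

            .. code-block:: python

                model = wav2vec2_base(aux_num_out=32)
                waveforms = torch.randn(2, 16000)
                logits, lengths = model(waveforms)
                # logits has shape (batch, frames, 32); lengths is None
                # because the lengths argument was not provided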
        """
        x, lengths = self.feature_extractor(waveforms, lengths)
        x = self.encoder(x, lengths)
        if self.aux is not None:
            x = self.aux(x)
        return x, lengths


def wav2vec2_model(
        extractor_mode: str,
        extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
        extractor_conv_bias: bool,
        encoder_embed_dim: int,
        encoder_projection_dropout: float,
        encoder_pos_conv_kernel: int,
        encoder_pos_conv_groups: int,
        encoder_num_layers: int,
        encoder_num_heads: int,
        encoder_attention_dropout: float,
        encoder_ff_interm_features: int,
        encoder_ff_interm_dropout: float,
        encoder_dropout: float,
        encoder_layer_norm_first: bool,
        encoder_layer_drop: float,
        aux_num_out: Optional[int],
) -> Wav2Vec2Model:
    # Overriding the signature so that the return type is correct on Sphinx
    """wav2vec2_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, aux_num_out: Optional[int]) -> torchaudio.models.Wav2Vec2Model

    Build a custom Wav2Vec2Model

    Note:
        The "feature extractor" below corresponds to
        `ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
        in the original ``fairseq`` implementation.
        This is referred to as "(convolutional) feature encoder" in the *wav2vec 2.0*
        [:footcite:`baevski2020wav2vec`] paper.

        The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
        and this is referred to as "Transformer" in the paper.

    Args:
        extractor_mode (str): Operation mode of feature extractor.
            Valid values are ``"group_norm"`` or ``"layer_norm"``.
            If ``"group_norm"``, then a single normalization is applied
            in the first convolution block. Otherwise, all the convolution
            blocks will have layer normalization.

            This option corresponds to ``extractor_mode`` from ``fairseq``.
        extractor_conv_layer_config (list of integer tuples or None):
            Configuration of convolution layers in feature extractor.
            List of convolution configurations,
            i.e. ``[(output_channel, kernel_size, stride), ...]``

            If ``None`` is provided, then the following default value is used.

            .. code-block:: python

               [
                 (512, 10, 5),
                 (512, 3, 2),
                 (512, 3, 2),
                 (512, 3, 2),
                 (512, 3, 2),
                 (512, 2, 2),
                 (512, 2, 2),
               ]

            This option corresponds to ``conv_feature_layers`` from ``fairseq``.

        extractor_conv_bias (bool):
            Whether to include a bias term in each convolution operation.

            This option corresponds to ``conv_bias`` from ``fairseq``.

        encoder_embed_dim (int):
            The dimension of embedding in encoder.

            This option corresponds to ``encoder_embed_dim`` from ``fairseq``.

        encoder_projection_dropout (float):
            The dropout probability applied after the input feature is projected
            to ``encoder_embed_dim``.

            This option corresponds to ``dropout_input`` from ``fairseq``.

        encoder_pos_conv_kernel (int):
            The kernel size of convolutional positional embeddings.

            This option corresponds to ``conv_pos`` from ``fairseq``.

        encoder_pos_conv_groups (int):
            The number of groups of convolutional positional embeddings.

            This option corresponds to ``conv_pos_groups`` from ``fairseq``.

        encoder_num_layers (int):
            The number of self-attention layers in the transformer block.

            This option corresponds to ``encoder_layers`` from ``fairseq``.

        encoder_num_heads (int):
            The number of heads in the self-attention layers.

            This option corresponds to ``encoder_attention_heads`` from ``fairseq``.

        encoder_attention_dropout (float):
            The dropout probability applied after softmax in the self-attention layer.

            This option corresponds to ``attention_dropout`` from ``fairseq``.

        encoder_ff_interm_features (int):
            The dimension of hidden features in the feed-forward layer.

            This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.

        encoder_ff_interm_dropout (float):
            The dropout probability applied in the feed-forward layer.

            This option corresponds to ``activation_dropout`` from ``fairseq``.

        encoder_dropout (float):
            The dropout probability applied at the end of the feed-forward layer.

            This option corresponds to ``dropout`` from ``fairseq``.

        encoder_layer_norm_first (bool):
            Control the order of layer norm in the transformer layer and in each encoder layer.
            If True, in the transformer layer, layer norm is applied before features are fed
            to the encoder layers; in each encoder layer, two layer norms are applied,
            one before and one after self-attention.
            If False, in the transformer layer, layer norm is applied after features are fed
            to the encoder layers; in each encoder layer, two layer norms are applied
            after self-attention, one before and one after the feed-forward block.

            This option corresponds to ``layer_norm_first`` from ``fairseq``.

        encoder_layer_drop (float):
            Probability to drop each encoder layer during training.

            This option corresponds to ``layerdrop`` from ``fairseq``.

        aux_num_out (int or None):
            When provided, attach an extra linear layer on top of the encoder, which can be
            used for fine-tuning.

    Returns:
        Wav2Vec2Model:
            The resulting model.
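
    Example:
        A sketch that reproduces the "base" configuration by hand; these are
        the same values that :py:func:`wav2vec2_base` below passes in.

        .. code-block:: python

            model = wav2vec2_model(
                extractor_mode="group_norm",
                extractor_conv_layer_config=None,  # use the default config above
                extractor_conv_bias=False,
                encoder_embed_dim=768,
                encoder_projection_dropout=0.1,
                encoder_pos_conv_kernel=128,
                encoder_pos_conv_groups=16,
                encoder_num_layers=12,
                encoder_num_heads=12,
                encoder_attention_dropout=0.1,
                encoder_ff_interm_features=3072,
                encoder_ff_interm_dropout=0.1,
                encoder_dropout=0.1,
                encoder_layer_norm_first=False,
                encoder_layer_drop=0.1,
                aux_num_out=None,
            )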
    """  # noqa: E501
    if extractor_conv_layer_config is None:
        extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    feature_extractor = components._get_feature_extractor(
        extractor_mode, extractor_conv_layer_config, extractor_conv_bias)
    encoder = components._get_encoder(
        in_features=extractor_conv_layer_config[-1][0],
        embed_dim=encoder_embed_dim,
        dropout_input=encoder_projection_dropout,
        pos_conv_kernel=encoder_pos_conv_kernel,
        pos_conv_groups=encoder_pos_conv_groups,
        num_layers=encoder_num_layers,
        num_heads=encoder_num_heads,
        attention_dropout=encoder_attention_dropout,
        ff_interm_features=encoder_ff_interm_features,
        ff_interm_dropout=encoder_ff_interm_dropout,
        dropout=encoder_dropout,
        layer_norm_first=encoder_layer_norm_first,
        layer_drop=encoder_layer_drop,
    )
    aux = None
    if aux_num_out is not None:
        aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
    return Wav2Vec2Model(feature_extractor, encoder, aux)


def wav2vec2_base(
        encoder_projection_dropout: float = 0.1,
        encoder_attention_dropout: float = 0.1,
        encoder_ff_interm_dropout: float = 0.1,
        encoder_dropout: float = 0.1,
        encoder_layer_drop: float = 0.1,
        aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    # Overriding the signature so that the return type is correct on Sphinx
    """wav2vec2_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

    Build Wav2Vec2Model with "base" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
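
    Example:
        A minimal sketch with batched audios of different durations; the
        ``aux_num_out`` value and the shapes are illustrative.

        .. code-block:: python

            model = wav2vec2_base(aux_num_out=32)
            # two clips, zero-padded to the duration of the longer one
            waveforms = torch.zeros(2, 32000)
            waveforms[0, :16000] = torch.randn(16000)
            waveforms[1] = torch.randn(32000)
            lengths = torch.tensor([16000, 32000])
            logits, output_lengths = model(waveforms, lengths)
            # output_lengths indicates the number of valid frames
            # in each row of logits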
    """  # noqa: E501
    return wav2vec2_model(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=False,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )


def wav2vec2_large(
        encoder_projection_dropout: float = 0.1,
        encoder_attention_dropout: float = 0.1,
        encoder_ff_interm_dropout: float = 0.1,
        encoder_dropout: float = 0.1,
        encoder_layer_drop: float = 0.1,
        aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    # Overriding the signature so that the return type is correct on Sphinx
    """wav2vec2_large(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

    Build Wav2Vec2Model with "large" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    return wav2vec2_model(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=False,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )


def wav2vec2_large_lv60k(
        encoder_projection_dropout: float = 0.1,
        encoder_attention_dropout: float = 0.0,
        encoder_ff_interm_dropout: float = 0.1,
        encoder_dropout: float = 0.0,
        encoder_layer_drop: float = 0.1,
        aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    # Overriding the signature so that the return type is correct on Sphinx
    """wav2vec2_large_lv60k( encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

    Build Wav2Vec2Model with "large lv-60k" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    return wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=True,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )


def hubert_base(
        encoder_projection_dropout: float = 0.1,
        encoder_attention_dropout: float = 0.1,
        encoder_ff_interm_dropout: float = 0.0,
        encoder_dropout: float = 0.1,
        encoder_layer_drop: float = 0.05,
        aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    # Overriding the signature so that the return type is correct on Sphinx
    """hubert_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

    Build HuBERT model with "base" architecture from *HuBERT* [:footcite:`hsu2021hubert`]

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
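
    Example:
        A minimal sketch extracting per-layer features for downstream use;
        the input shape is illustrative.

        .. code-block:: python

            model = hubert_base()
            waveforms = torch.randn(1, 16000)
            features, _ = model.extract_features(waveforms)
            # features is a list with one Tensor per transformer layer,
            # each of shape (batch, frames, 768)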
    """  # noqa: E501
    return wav2vec2_model(
        extractor_mode='group_norm',
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=False,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )


def hubert_large(
        encoder_projection_dropout: float = 0.0,
        encoder_attention_dropout: float = 0.0,
        encoder_ff_interm_dropout: float = 0.0,
        encoder_dropout: float = 0.0,
        encoder_layer_drop: float = 0.0,
        aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    # Overriding the signature so that the return type is correct on Sphinx
    """hubert_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

    Build HuBERT model with "large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    return wav2vec2_model(
        extractor_mode='layer_norm',
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=True,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )


def hubert_xlarge(
        encoder_projection_dropout: float = 0.0,
        encoder_attention_dropout: float = 0.0,
        encoder_ff_interm_dropout: float = 0.0,
        encoder_dropout: float = 0.0,
        encoder_layer_drop: float = 0.0,
        aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    # Overriding the signature so that the return type is correct on Sphinx
    """hubert_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

    Build HuBERT model with "extra large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    return wav2vec2_model(
        extractor_mode='layer_norm',
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1280,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=48,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=5120,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=True,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )