# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import math
9
from typing import Optional, Union
10
11
12
13
14
15

import torch
import torch.nn.functional as F
from torch import Tensor, nn


16
class BlockBase(nn.Module):
    """Block abstract module that records its input and output sizes.

    Args:
        input_size: size of the block's input.
        output_size: size of the block's output.
    """

    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        # Stored for subclasses / callers; not used by this base class itself.
        self.input_size = input_size
        self.output_size = output_size


25
def get_activation(name: str = "relu") -> torch.nn.Module:
    """Select an activation module by name.

    Args:
        name: case-insensitive activation name, one of
            ``"relu"``, ``"gelu"``, ``"swish"``, ``"sigmoid"``.
            Default ``"relu"``.

    Returns:
        A newly constructed activation module. ``"relu"`` yields
        ``nn.ReLU(inplace=True)``, which overwrites its input tensor.
        Any unrecognised name falls back to ``nn.Identity()``; no error
        is raised.
    """
    name = name.lower()
    if name == "relu":
        # In-place: the tensor passed to this module is modified.
        return nn.ReLU(inplace=True)
    if name == "gelu":
        return nn.GELU()
    if name == "swish":
        return Swish()
    if name == "sigmoid":
        return torch.nn.Sigmoid()
    return nn.Identity()


46
47
48
def adaptive_enc_mask(
    x_len: int, chunk_start_idx: list[int], left_window: int = 0, right_window: int = 0
) -> torch.Tensor:
    """Build the chunk-wise attention mask for Transformer Transducer streaming.

    Each position may attend to every position of its own chunk, plus up to
    ``left_window`` chunks to the left and ``right_window`` chunks to the
    right.

    Args:
        x_len: sequence length.
        chunk_start_idx: first index of each chunk, e.g. ``[0, 18, 36, 48]``.
            Adaptive chunk sizes are supported, e.g. ``[0, 10, 15, 45]``.
            Indices are expected to be strictly increasing.
        left_window: how many chunks to the left can be seen.
        right_window: how many chunks to the right can be seen. It is used
            for chunk-overlap models.

    Returns:
        Boolean tensor of shape ``(x_len, x_len)``; ``mask[i, j]`` is True
        when position ``i`` may attend to position ``j``. For example,
        ``adaptive_enc_mask(4, [0, 2])`` returns::

            tensor([[ True,  True, False, False],
                    [ True,  True, False, False],
                    [False, False,  True,  True],
                    [False, False,  True,  True]])
    """
    # Build an integer tensor directly; ``torch.Tensor(list).long()`` goes
    # through float32 and loses precision for large indices.
    chunk_start = torch.as_tensor(chunk_start_idx, dtype=torch.long)
    num_chunks = chunk_start.numel()

    # Prepend 0, e.g. [0, 0, 18, 36, 48]: start of each padded chunk slot.
    start_pad = F.pad(chunk_start, (1, 0))
    # Append x_len, e.g. [0, 18, 36, 48, x_len]: end of each padded slot.
    end_pad = F.pad(chunk_start, (0, 1), value=x_len)

    # Padded slot index of the chunk containing each position: (x_len,).
    positions = torch.arange(0, x_len).unsqueeze(-1)  # (x_len, 1)
    in_chunk = (positions >= start_pad) & (positions < end_pad)
    chunk_of = in_chunk.nonzero()[:, 1]

    # Column (key) index for every row: (x_len, x_len).
    cols = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)

    # Left boundary: start of the chunk ``left_window`` slots to the left.
    left_slot = (chunk_of - left_window).clamp(min=0)
    mask_left = cols >= start_pad[left_slot].unsqueeze(-1)

    # Right boundary: end of the chunk ``right_window`` slots to the right,
    # capped at the last slot (whose end is x_len).
    right_slot = (chunk_of + right_window).clamp(max=num_chunks)
    mask_right = cols < end_pad[right_slot].unsqueeze(-1)

    return mask_left & mask_right


class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    Reference: https://arxiv.org/pdf/2005.03191.pdf
    """

    def __init__(self) -> None:
        super().__init__()
        # The sigmoid is held as a submodule named ``act_fn``.
        self.act_fn = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` element-wise by ``sigmoid(x)``.

        Args:
            x: input tensor (any shape).
        """
        gate = self.act_fn(x)
        return x * gate


class GLU(nn.Module):
    """Gated Linear Unit (GLU) module.

    Splits the input into two halves along ``dim`` and returns
    ``first_half * act(second_half)``.

    Args:
        dim: dimension along which the input is split. Default -1.
        act_name: activation applied to the gate half, one of ``"relu"``,
            ``"gelu"``, ``"swish"``, ``"sigmoid"`` (case-insensitive); any
            other value means no activation. Default ``"sigmoid"``.
    """

    def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
        super().__init__()
        self.dim = dim
        self.act_name = act_name.lower()

        if self.act_name == "relu":
            # Not in-place: the gate is a ``chunk`` view of the caller's
            # input, so an in-place ReLU would overwrite half of that tensor.
            self.act_fn = nn.ReLU()
        elif self.act_name == "gelu":
            self.act_fn = nn.GELU()
        elif self.act_name == "swish":
            self.act_fn = Swish()
        elif self.act_name == "sigmoid":
            self.act_fn = nn.Sigmoid()
        else:
            self.act_fn = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """GLU forward.

        Multiplies the first half of ``x`` (along ``self.dim``) by the
        activation of the second half.

        Args:
            x: input tensor; its size along ``self.dim`` is expected to be
                even.

        Returns:
            Tensor with half the size of ``x`` along ``self.dim``.
        """
        half_x, gate = x.chunk(2, dim=self.dim)
        return half_x * self.act_fn(gate)


# TODO: Abdel, this can be improved using GLU module
class GLUPointWiseConv(nn.Module):
    """GLUPointWiseConv module used for the conformer architecture.

    For more details see https://arxiv.org/pdf/2005.08100v1.pdf

    A Conv1d produces ``2 * output_dim`` channels; the first half is gated
    by the activation of the second half.

    Args:
        input_dim: input channel size.
        output_dim: output channel size.
        kernel_size: kernel size of the Conv1d.
        glu_type: gate activation, one of ``"sigmoid"``, ``"relu"``,
            ``"gelu"``, ``"swish"``, or ``"bilinear"`` (no activation: the
            two halves are multiplied directly). Default ``"sigmoid"``.
        bias_in_glu: if True, add learnable biases to both halves before
            gating. Default True.
        causal: if True, the conv is padded by ``kernel_size - 1`` on both
            sides, so the output is ``kernel_size - 1`` frames longer than
            the input and the caller must drop the trailing frames to keep
            it causal. Otherwise "same" padding ``(kernel_size - 1) // 2``
            is used. Default False.

    Raises:
        ValueError: if ``glu_type`` is not supported.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        kernel_size: int,
        glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
        causal: bool = False,
    ) -> None:
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim
        self.bias_in_glu = bias_in_glu
        padding = kernel_size - 1 if causal else (kernel_size - 1) // 2
        self.ext_pw_conv_1d = nn.Conv1d(
            input_dim,
            output_dim * 2,
            kernel_size,
            1,
            padding=padding,
        )

        if glu_type == "sigmoid":
            self.glu_act = nn.Sigmoid()
        elif glu_type == "relu":
            self.glu_act = nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = nn.GELU()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "bilinear":
            # Bilinear gating multiplies the two halves without activation.
            self.glu_act = nn.Identity()
        else:
            # Report the requested name; ``self.glu_act`` does not exist yet.
            raise ValueError(f"Unsupported activation type {glu_type}")

        if bias_in_glu:
            self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
            self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the pointwise-conv GLU.

        Args:
            x: input tensor with channels in the last dimension,
                ``(batch, time, input_dim)`` (consistent with GLULinear).

        Returns:
            Tensor of shape ``(batch, time', output_dim)``; ``time'`` equals
            ``time`` for odd ``kernel_size`` without ``causal``.
        """
        # Conv1d expects channels in dimension 1.
        x = x.permute([0, 2, 1])
        x = self.ext_pw_conv_1d(x)
        value = x[:, 0 : self.output_dim, :]
        gate = x[:, self.output_dim : self.output_dim * 2, :]
        if self.bias_in_glu:
            value = value + self.b1
            gate = gate + self.b2
        # For "bilinear" glu_act is Identity, i.e. value * gate.
        x = value * self.glu_act(gate)
        x = x.permute([0, 2, 1])
        return x


class DepthWiseSeperableConv1d(nn.Module):
    """Depthwise-separable Conv1d used in the conformer conv module.

    For more details see https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: input channel size.
        depthwise_seperable_out_channel: if non-zero, the output channel
            size of a second, pointwise (kernel size 1) Conv1d applied after
            the depthwise conv. If 0, the second conv1d layer is skipped and
            the output has ``input_dim * depthwise_multiplier`` channels.
        kernel_size: kernel size of the depthwise conv.
        depthwise_multiplier: number of input_dim channel duplications; the
            depthwise conv outputs ``input_dim * depthwise_multiplier``
            channels.
        padding: padding of the depthwise conv. Default 0.
    """

    def __init__(
        self,
        input_dim: int,
        depthwise_seperable_out_channel: int,
        kernel_size: int,
        depthwise_multiplier: int,
        padding: int = 0,
    ) -> None:
        super().__init__()

        # groups=input_dim makes this a depthwise convolution.
        self.dw_conv = nn.Conv1d(
            input_dim,
            input_dim * depthwise_multiplier,
            kernel_size,
            1,
            padding=padding,
            groups=input_dim,
        )

        if depthwise_seperable_out_channel != 0:
            self.pw_conv = nn.Conv1d(
                input_dim * depthwise_multiplier,
                depthwise_seperable_out_channel,
                1,
                1,
                0,
            )
        else:
            self.pw_conv = nn.Identity()
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

    def forward(self, x: Tensor) -> Tensor:
        """Apply the depthwise conv, then the optional pointwise conv.

        Args:
            x: input tensor in Conv1d layout ``(batch, input_dim, time)``.

        Returns:
            Tensor of shape ``(batch, out_channels, time')``.
        """
        x = self.dw_conv(x)
        if self.depthwise_seperable_out_channel != 0:
            x = self.pw_conv(x)
        return x


class ConvModule(nn.Module):
    """ConvModule for the conformer block.

    For more details see https://arxiv.org/pdf/2005.08100v1.pdf

    Forward pipeline: LayerNorm -> GLU (or scalar affine) ->
    depthwise-separable conv -> optional Linear / BatchNorm -> activation ->
    pointwise conv (or scalar affine) -> dropout.

    Args:
        input_dim: input channel size.
        ext_pw_out_channel: if non-zero, the channel size of the GLU layer
            applied before the depthwise conv and of the pointwise conv
            applied after the activation. If 0, both are replaced by
            learnable scalar affine transforms.
        depthwise_seperable_out_channel: if non-zero, used as the output
            channel size of the second conv1d layer of the
            depthwise-separable conv; if 0, that layer is skipped.
        ext_pw_kernel_size: kernel size of the conformer's pointwise convs.
        kernel_size: kernel size of the depthwise conv.
        depthwise_multiplier: number of input_dim channel duplications; used
            to compute the hidden channels of the depthwise Conv1d.
        dropout_rate: dropout rate.
        causal: if True, convolutions have no access to future frames.
            Default False.
        batch_norm: if True, apply BatchNorm1d before the activation.
            Default False.
        chunk_se: squeeze-excitation mode as documented upstream (0 offline;
            1 streaming with accumulated-history mean; 2 streaming with
            current-chunk mean). Accepted for configuration compatibility;
            not used by this implementation.
        chunk_size: CNN chunk size as documented upstream. Accepted for
            configuration compatibility; not used. Default 18.
        activation: activation function used after the depthwise conv.
            Default "relu".
        glu_type: activation function used for the GLU. Default "sigmoid".
        bias_in_glu: if True, use additive bias in the layer before the GLU.
        linear_glu_in_convm: if True, use a GLULinear module; otherwise use a
            GLUPointWiseConv module. Default False.
        export: if True (and ``causal`` is True), the depthwise conv uses no
            padding. This is for inference or ONNX export; typically set by
            the export or decoder program, not by the config file.
            Default False.
    """

    def __init__(
        self,
        input_dim: int,
        ext_pw_out_channel: int,
        depthwise_seperable_out_channel: int,
        ext_pw_kernel_size: int,
        kernel_size: int,
        depthwise_multiplier: int,
        dropout_rate: float,
        causal: bool = False,
        batch_norm: bool = False,
        chunk_se: int = 0,
        chunk_size: int = 18,
        activation: str = "relu",
        glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
        linear_glu_in_convm: bool = False,
        export: bool = False,
    ) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.input_dim = input_dim
        self.ext_pw_out_channel = ext_pw_out_channel
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.glu_type = glu_type
        self.bias_in_glu = bias_in_glu
        self.linear_glu_in_convm = linear_glu_in_convm
        self.causal = causal

        # Reads the attributes set above, so it must run after them.
        self._add_ext_pw_layer()

        self.batch_norm = batch_norm
        self.kernel_size = kernel_size

        if batch_norm:
            self.bn_layer = nn.BatchNorm1d(input_dim)

        self.act = get_activation(activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.export = export

        if causal:
            # Full left padding; forward() trims the trailing frames.
            padding = 0 if export else kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2

        self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
            input_dim,
            depthwise_seperable_out_channel,
            kernel_size,
            depthwise_multiplier,
            padding=padding,
        )

        # Project the depthwise-separable output back to input_dim when its
        # channel count differs.
        if depthwise_seperable_out_channel != 0:
            if input_dim != depthwise_seperable_out_channel:
                self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim)
        elif depthwise_multiplier != 1:
            self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim)

    def _add_ext_pw_layer(self) -> None:
        """Create the GLU and final pointwise layers (extension of __init__).

        Reads ``input_dim``, ``ext_pw_out_channel``, ``ext_pw_kernel_size``,
        ``causal``, ``linear_glu_in_convm``, ``glu_type`` and ``bias_in_glu``
        from ``self``.
        """
        # Placeholders so these attributes always exist (jit hacks).
        self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = nn.Identity()
        self.squeeze_excitation = nn.Identity()  # jit.
        self.apply_ln1 = self.fix_len1 = False  # jit.

        if self.ext_pw_out_channel != 0:
            if self.causal:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1),
                )
                # Causal padding adds trailing frames that forward() trims.
                self.fix_len1 = self.ext_pw_kernel_size > 1
            else:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1) // 2,
                )
                self.fix_len1 = False

            if self.linear_glu_in_convm:
                self.glu = GLULinear(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.glu_type,
                    self.bias_in_glu,
                )
            else:
                self.glu = GLUPointWiseConv(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    self.glu_type,
                    self.bias_in_glu,
                    self.causal,
                )

            if self.input_dim != self.ext_pw_out_channel:
                self.apply_ln1 = True
                self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
            else:
                self.apply_ln1 = False
        else:
            # Three scalar (weight, bias) pairs replace the pointwise layers:
            # indices 0 and 1 before the depthwise conv, index 2 after.
            self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
            self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x: Tensor) -> Tensor:
        """ConvModule forward.

        Args:
            x: input tensor with channels last, ``(batch, time, input_dim)``;
                dimension 1 is treated as time by the causal trims.

        Returns:
            Tensor in the same ``(batch, time, input_dim)`` layout.
        """
        x = self.layer_norm(x)

        if self.ext_pw_out_channel != 0:
            x = self.glu(x)
            # Only GLUPointWiseConv pads in time when causal; GLULinear keeps
            # the length, so trimming its output would drop real frames.
            if (
                self.causal
                and self.ext_pw_kernel_size > 1
                and not self.linear_glu_in_convm
            ):
                x = x[:, : -(self.ext_pw_kernel_size - 1), :]
            if self.apply_ln1:
                x = self.ln1(x)
        else:
            x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
            x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
            x = x_0 + x_1

        # (batch, time, channel) -> (batch, channel, time) for Conv1d.
        x = x.permute([0, 2, 1])

        x = self.dw_sep_conv_1d(x)
        if self.causal and self.kernel_size > 1:
            # Drop trailing frames produced by causal padding.
            x = x[:, :, : -(self.kernel_size - 1)]
        if hasattr(self, "ln2"):
            x = x.permute([0, 2, 1])
            x = self.ln2(x)
            x = x.permute([0, 2, 1])
        if self.batch_norm:
            x = self.bn_layer(x)
        x = self.act(x)

        if self.ext_pw_out_channel != 0:
            x = self.ext_pw_conv_1d(x)
            if self.fix_len1:
                x = x[:, :, : -(self.ext_pw_kernel_size - 1)]

            if self.apply_ln1:
                x = x.permute([0, 2, 1])
                x = self.ln1(x)
                x = x.permute([0, 2, 1])

            x = x.permute([0, 2, 1])
        else:
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
            x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
            x = x.squeeze(1)

        x = self.dropout(x)
        return x


class GLULinear(nn.Module):
    """Linear + GLU module.

    Projects the last dimension from ``input_dim`` to ``2 * output_dim``,
    then gates the first half with the activation of the second half.

    Args:
        input_dim: input size.
        output_dim: output size.
        glu_type: activation name used for the gate in the GLU module.
            Default ``"sigmoid"``.
        bias_in_glu: if True, the linear layer has an additive bias.
            Default True.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
        self.glu_act = GLU(-1, glu_type)

    def forward(self, x: Tensor) -> Tensor:
        """GLULinear forward.

        Args:
            x: input tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor whose last dimension has size ``output_dim``.
        """
        x = self.linear(x)
        return self.glu_act(x)


class FeedForward(nn.Module):
    """FeedForward module of the conformer block.

    For more details see the Conformer paper:
    https://arxiv.org/pdf/2005.08100.pdf

    Computes ``LayerNorm -> GLULinear -> Dropout -> Linear -> Dropout``.
    No residual connection is added here.

    Args:
        d_model: input (and output) size.
        d_inner: hidden size produced by the GLU layer.
        dropout_rate: dropout rate.
        activation: gate activation name passed to ``GLULinear``, one of
            ``"relu"``, ``"gelu"``, ``"swish"``, ``"sigmoid"``.
            Default ``"sigmoid"``.
        bias_in_glu: if True, the GLU's linear layer has an additive bias.
            Default True.
    """

    def __init__(
        self,
        d_model: int,
        d_inner: int,
        dropout_rate: float,
        activation: str = "sigmoid",
        bias_in_glu: bool = True,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        module = GLULinear(d_model, d_inner, activation, bias_in_glu)
        self.net = nn.Sequential(
            module,
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x: Tensor) -> Tensor:
        """FeedForward forward function.

        Args:
            x: input tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return self.net(self.layer_norm(x))


#### positional encoding starts here
def _pre_hook(
648
649
650
651
652
653
654
655
    state_dict: dict,
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: list[str],
    unexpected_keys: list[str],
    error_msgs: list[str],
) -> None:
656
657
658
659
    """Perform pre-hook in load_state_dict for backward compatibility.

    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
660
        Therefore, we remove the item "pe" from `state_dict` for backward
661
662
663
664
665
666
667
668
669
670
        compatibility.

    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)


class T5RelativeAttentionLogitBias(nn.Module):
    """
    This module implements the relative position bias described in Section
    2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf

    The Huggingface implementation is used as a reference
    https://github.com/huggingface/transformers/blob/v4.30.0/src/
    transformers/models/t5/modeling_t5.py#L435

    Modifies attention as Q*K^T + B, where B is a learned scalar bias based
    on relative position of the query and key. It is HxNxN, where H is the
    number of heads, N is the sequence length.

    Modifications relative to the original T5 bias:

    - Skipping of the bucketing step. Original T5 bias converted rel
      position distances into logarithmically increasing buckets. This is
      supposed to help with length generalization.
    - The rel position index is used directly as the bias index, as length
      generalization isn't needed (40s max is good enough for the ASR
      encoder), and it keeps ONNX export simple.
    - Biases can be asymmetric; the default implementation treats L->R and
      R->L the same. Asymmetric was found to yield better results in the
      authors' experiments.

    Args:
        num_heads: int
            Number of attention heads.
        num_buckets: int
            Number of buckets for the relative attention bias. Bucketing is
            not yet supported: only the default -1 (no bucketing;
            ``max_distance`` determines the size of the bias param) is
            accepted, other values raise ``NotImplementedError``.
        max_distance: int
            Maximum relative distance. With num_buckets=-1 this directly
            controls the size of the bias parameter; larger distances share
            the bias of the clipped distance.
        symmetric: bool
            Whether to use symmetric or asymmetric biases. symmetric=False
            uses 2x number of bias params to distinguish L->R from R->L.
            This was found to be better for the encoder.
    """

    def __init__(
        self,
        num_heads: int,
        num_buckets: int = -1,
        max_distance: int = 1000,
        symmetric: bool = False,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.symmetric = symmetric
        self._skip_bucketing = self.num_buckets < 0
        if self._skip_bucketing:
            self.num_buckets = max_distance
        else:
            raise NotImplementedError(
                "T5 attention bias with bucketed positions is not yet tested"
            )
        if not self.symmetric:
            self.num_buckets *= 2
        self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(self, x: Tensor) -> Tensor:
        """Return the relative-position bias for a sequence shaped like ``x``.

        Args:
            x: tensor whose dimension 1 is the sequence length ``L``; only
                its size and device are used.

        Returns:
            Bias tensor of shape ``(1, num_heads, L, L)``.
        """
        # instantiate bias compatible with shape of x
        maxpos = x.size(1)
        context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[
            :, None
        ]
        memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[
            None, :
        ]
        relative_position = memory_position - context_position
        # Clip to [-max_distance, max_distance - 1] using ops that play well
        # with ONNX export.
        relative_position = relative_position.masked_fill(
            relative_position < -self.max_distance, -self.max_distance
        )
        relative_position = relative_position.masked_fill(
            relative_position > self.max_distance - 1, self.max_distance - 1
        )

        # mapping from relative position to index in the bias parameter
        if self._skip_bucketing:
            bias_idx = relative_position
        else:
            bias_idx = self._bucket_relative_position(relative_position)
        if self.symmetric:
            bias_idx = bias_idx.abs()
            # abs() of the lower clip bound is max_distance, one past the last
            # embedding row; clip it back into range.
            bias_idx = bias_idx.masked_fill(
                bias_idx > self.num_buckets - 1, self.num_buckets - 1
            )
        else:
            # Shift [-max_distance, max_distance - 1] to [0, num_buckets - 1].
            bias_idx += self.num_buckets // 2

        t5_rel_att_bias = self.bias_values(bias_idx)  # [L, L, H]
        t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0)  # [1, H, L, L]

        return t5_rel_att_bias

    def _bucket_relative_position(self, relative_position: Tensor) -> Tensor:
        """Map relative positions to T5-style logarithmic buckets.

        This is a placeholder (untested, likely buggy) using the HuggingFace
        implementation as a reference; it is currently unreachable because
        ``__init__`` rejects bucketing. It also still needs extending to
        support asymmetric +/- positions.
        """
        # NOTE(review): this class defines no ``causal`` attribute; treat it
        # as bidirectional unless one is set explicitly.
        causal = getattr(self, "causal", False)
        # Local copy so repeated calls don't keep halving self.num_buckets.
        num_buckets = self.num_buckets
        relative_buckets = 0
        if not causal:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in
        # positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(self.max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets


class AbsolutePositionalEncoding(nn.Module):
    """Absolute sinusoidal positional encoding module.

    Implements the sinusoidal encoding from
    https://arxiv.org/pdf/1706.03762.pdf. The input is scaled by
    ``sqrt(d_model)`` before the encoding is added.

    Args:
        d_model: int
            Input embedding size (odd sizes are supported).
        dropout_rate: float
            Dropout rate applied to the output.
        max_len: int, optional
            Number of positions precomputed at construction; longer inputs
            extend the table on demand. Default 5000.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
        """Construct an AbsolutePositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Plain attribute rather than a registered buffer, so it is not part
        # of the state_dict.
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        # Drop "pe" entries found in checkpoints written by older versions.
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x: torch.Tensor) -> None:
        """Ensure ``self.pe`` covers ``x.size(1)`` positions.

        A cached table that is long enough is reused, only converted to the
        dtype/device of ``x``; otherwise the table is recomputed for exactly
        ``x.size(1)`` positions.

        Args:
            x: tensor whose dimension 1 is the sequence length; its dtype and
                device determine those of ``self.pe``.
        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        pe = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        self.pe = pe.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding.

        Args:
            x: input tensor of shape ``(batch, time, d_model)``.

        Returns:
            Encoded tensor of the same shape, with dropout applied.
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


#### forward embedding layers starts here
class MeanVarianceNormLayer(nn.Module):
    """Normalize features with a stored global mean and inverse std.

    Computes ``(input - global_mean) * global_invstd``.
    Typically used as a very first layer in a model.

    Args:
        input_size: int
            feature size; sets the length of both statistic vectors.
    """

    def __init__(self, input_size: int) -> None:
        super().__init__()
        self.input_size = input_size
        # Identity transform at construction (mean 0, inverse std 1);
        # presumably overwritten by loaded weights — TODO confirm.
        self.global_mean = nn.Parameter(torch.zeros(input_size))
        self.global_invstd = nn.Parameter(torch.ones(input_size))

    def forward(self, input_: Tensor) -> Tensor:
        """Apply mean/variance normalization.

        Args:
            input_: tensor broadcastable against shape (input_size,),
                e.g. with last dimension equal to ``input_size``.

        Returns:
            The normalized tensor.
        """
        centered = input_ - self.global_mean
        return centered * self.global_invstd


class CausalConv1D(nn.Conv1d):
    """
    A causal version of nn.Conv1d where each step has limited access to
    locations on its right or left.
    All arguments are the same as nn.Conv1d except padding.

    If padding is None, paddings are set automatically to make it a causal
    convolution where each location does not see any steps on its right
    (left padding ``kernel_size - 1``, right padding ``stride - 1``).

    If padding is a list (size of 2), padding[0] is used as left padding and
    padding[1] as right padding. This makes it possible to control the number
    of steps accessible on the right and left. This mode is not supported
    when stride > 1, and padding[0] + padding[1] must equal
    (kernel_size - 1).

    If padding is an int, it is applied on both sides.

    Streaming: ``forward(x, cache)`` prepends ``cache`` in place of the left
    padding and also returns the next cache, which has the same length as
    ``cache``. Set ``cache_drop_size`` to a positive int to exclude that many
    trailing frames from the next cache.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Optional[Union[int, list[int]]] = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        # Trailing frames excluded from the next streaming cache; None or a
        # non-positive value keeps them all (see update_cache).
        self.cache_drop_size = None
        if padding is None:
            self._left_padding = kernel_size - 1
            self._right_padding = stride - 1
        else:
            if stride != 1 and padding != kernel_size - 1:
                raise ValueError("No striding allowed for non-symmetric convolutions!")
            if isinstance(padding, int):
                self._left_padding = padding
                self._right_padding = padding
            elif (
                isinstance(padding, list)
                and len(padding) == 2
                and padding[0] + padding[1] == kernel_size - 1
            ):
                self._left_padding = padding[0]
                self._right_padding = padding[1]
            else:
                raise ValueError(f"Invalid padding param: {padding}!")

        # Not read in this class; presumably consumed by streaming callers.
        self._max_cache_len = self._left_padding

        # Padding is applied manually in update_cache, so the underlying
        # convolution itself is unpadded.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def update_cache(
        self, x: Tensor, cache: Optional[Tensor] = None
    ) -> tuple[Tensor, Optional[Tensor]]:
        """Pad ``x`` (or prepend ``cache``) and compute the next cache.

        Args:
            x: input tensor; padding/concatenation is along the last axis.
            cache: previous streaming context, or None for non-streaming use.

        Returns:
            ``(padded_x, next_cache)``. ``next_cache`` is None when ``cache``
            is None; otherwise it holds the last ``cache.size(-1)`` frames of
            the padded input after dropping ``cache_drop_size`` trailing
            frames (when that is a positive int).
        """
        if cache is None:
            new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
            return new_x, None

        new_x = F.pad(x, pad=(0, self._right_padding))
        new_x = torch.cat([cache, new_x], dim=-1)

        # cache_drop_size defaults to None; guard before comparing so a
        # plain streaming call does not raise TypeError.
        drop = self.cache_drop_size
        next_cache = new_x[:, :, :-drop] if drop is not None and drop > 0 else new_x

        # Keep exactly cache.size(-1) trailing frames. An explicit start
        # index avoids the `-0:` slice, which would keep everything when the
        # cache is empty.
        start = max(next_cache.size(-1) - cache.size(-1), 0)
        next_cache = next_cache[:, :, start:]
        return new_x, next_cache

    def forward(
        self, x: Tensor, cache: Optional[Tensor] = None
    ) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]:
        """Run the causal convolution.

        Returns only the output when ``cache`` is None, otherwise
        ``(output, next_cache)``.
        """
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        if cache is None:
            return x
        else:
            return x, cache


class CausalConv2D(nn.Conv2d):
    """
    A causal variant of nn.Conv2d.

    All arguments are the same as nn.Conv2d except ``padding``, which must be
    None (the default); padding is applied explicitly in ``forward`` instead.

    NOTE(review): ``forward`` pads only the last dimension
    (``kernel_size - 1`` on the left, ``stride - 1`` on the right); the
    second-to-last dimension is not padded, so only the last axis is causal.
    Confirm this is the intended behaviour.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Optional[Union[str, int]] = None,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        # The default used to be 0, which this very check rejected; None is
        # the only accepted value, so it is now the default.
        if padding is not None:
            raise ValueError("Argument padding should be set to None for CausalConv2D.")
        self._left_padding = kernel_size - 1
        self._right_padding = stride - 1

        # The wrapped convolution is unpadded; forward() pads manually.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        x: Tensor,
    ) -> Tensor:
        """Pad the last dimension asymmetrically, then convolve."""
        x = F.pad(
            x,
            pad=(self._left_padding, self._right_padding, 0, 0),
        )
        return super().forward(x)


class NemoConvSubsampling(torch.nn.Module):
    """Convolutional subsampling module, taken from NeMo ASR
    (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a
    34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)

    Striding Subsampling: "Speech-Transformer: A No-Recurrence
    Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong
    et al. (https://ieeexplore.ieee.org/document/8462506)

    Compared with the EncoderConv2D (`input_layer: custom`), this is a
    much simplified approach, and uses no LayerNorm and far fewer Conv2Ds.
    Moreover, depthwise convolutions are used to reduce FLOPs, but the first
    layer is kept as a regular convolution so as not to degrade accuracy.

    `striding` and `dw_striding` are the same except that the latter uses
    depthwise convolutions after the first layer, whereas the former does not.

    Args:
        feat_in (int): size of the input features
        feat_out (int): size of the output features
        subsampling_factor (int): Time reduction factor. Must be a power of 2
            (each stage uses stride 2). Default is 4.
        subsampling (str): The subsampling technique, choose from
            {"dw_striding", "striding", "striding_conv1d",
            "dw_striding_conv1d"}. Default is "dw_striding".
        conv_channels (int): Number of channels for the convolution layers,
            default is 256.
        subsampling_conv_chunking_factor (int): Input chunking factor which
            can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1
        activation (Module): activation function, default is nn.ReLU().
            The same instance is reused after every stage.
        is_causal (bool): whether to use causal Conv1/2D, where each step will
            have limited access to locations on its right or left. Ignored by
            "dw_striding_conv1d".

    Raises:
        ValueError: for a ``subsampling_factor`` that is not a power of 2,
            an invalid chunking factor, or an unknown ``subsampling`` name.
    """

    def __init__(
        self,
        feat_in: int,
        feat_out: int,
        subsampling_factor: int = 4,
        subsampling: str = "dw_striding",
        conv_channels: int = 256,
        subsampling_conv_chunking_factor: int = 1,
        activation: torch.nn.Module = nn.ReLU(),  # noqa: B008
        is_causal: bool = False,
    ) -> None:
        super().__init__()
        self._subsampling = subsampling
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out

        # Every stage halves the time axis (stride 2), so only powers of 2 are
        # reachable. Any other factor (e.g. 6) would leave _sampling_num out of
        # sync with subsampling_factor, which forward() uses for the mask.
        if subsampling_factor < 2:
            raise ValueError(
                f"Sampling factor should be a power of 2, got {subsampling_factor}!"
            )
        # math.log2 is exact for powers of two; math.log(x, 2) may not be.
        self._sampling_num = int(math.log2(subsampling_factor))
        if 2**self._sampling_num != subsampling_factor:
            raise ValueError(
                f"Sampling factor should be a power of 2, got {subsampling_factor}!"
            )
        self.subsampling_factor = subsampling_factor
        self.is_causal = is_causal
        self.subsampling_causal_cond = subsampling in (
            "dw_striding",
            "striding",
            "striding_conv1d",
        )

        # Validates and stores the chunking factor (shared with the setter).
        self.change_subsampling_conv_chunking_factor(subsampling_conv_chunking_factor)

        layers: list[nn.Module] = []

        if subsampling == "dw_striding":
            self._set_conv_geometry(3, subsampling_factor, causal=is_causal)
            # Layer 1: regular conv from the single input plane.
            layers.append(self._make_conv2d(1, conv_channels))
            layers.append(activation)
            # Remaining stages: depthwise strided conv + 1x1 pointwise conv.
            for _ in range(self._sampling_num - 1):
                layers.append(
                    self._make_conv2d(
                        conv_channels, conv_channels, groups=conv_channels
                    )
                )
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=conv_channels,
                        out_channels=conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    )
                )
                layers.append(activation)

        elif subsampling == "striding":
            self._set_conv_geometry(3, subsampling_factor, causal=is_causal)
            in_channels = 1
            for _ in range(self._sampling_num):
                layers.append(self._make_conv2d(in_channels, conv_channels))
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding_conv1d":
            self._set_conv_geometry(5, subsampling_factor, causal=is_causal)
            in_channels = feat_in
            for stage in range(self._sampling_num):
                # The last stage projects straight to feat_out.
                out_channels = (
                    feat_out if stage == self._sampling_num - 1 else conv_channels
                )
                layers.append(self._make_conv1d(in_channels, out_channels))
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "dw_striding_conv1d":
            # This variant always uses symmetric padding, whatever is_causal is.
            self._set_conv_geometry(5, subsampling_factor, causal=False)
            in_channels = feat_in
            for stage in range(self._sampling_num):
                out_channels = (
                    feat_out if stage == self._sampling_num - 1 else conv_channels
                )
                layers.extend(
                    [
                        # depthwise strided conv
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        ),
                        # pointwise conv
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            groups=1,
                        ),
                        activation,
                    ]
                )
                in_channels = conv_channels

        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        if subsampling in ("dw_striding", "striding"):
            # The 2D convs also shrink the feature axis; size the output
            # projection from the resulting feature length.
            out_length = calc_length(
                lengths=torch.tensor(feat_in, dtype=torch.float),
                all_paddings=self._left_padding + self._right_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                ceil_mode=self._ceil_mode,
                repeat_num=self._sampling_num,
            )
            self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
            self.conv2d_subsampling = True
        else:
            self.out = None
            self.conv2d_subsampling = False

        self.conv = torch.nn.Sequential(*layers)

    def _set_conv_geometry(
        self, kernel_size: int, subsampling_factor: int, causal: bool
    ) -> None:
        """Record stride, kernel, padding and streaming cache length."""
        self._stride = 2
        self._kernel_size = kernel_size
        self._ceil_mode = False
        if causal:
            self._left_padding = kernel_size - 1
            self._right_padding = self._stride - 1
            self._max_cache_len = subsampling_factor + 1
        else:
            self._left_padding = (kernel_size - 1) // 2
            self._right_padding = (kernel_size - 1) // 2
            self._max_cache_len = 0

    def _make_conv2d(
        self, in_channels: int, out_channels: int, groups: int = 1
    ) -> nn.Module:
        """Build one strided KxK 2D conv stage (causal or symmetric-padded)."""
        if self.is_causal:
            return CausalConv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self._kernel_size,
                stride=self._stride,
                padding=None,
                groups=groups,
            )
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self._kernel_size,
            stride=self._stride,
            padding=self._left_padding,
            groups=groups,
        )

    def _make_conv1d(self, in_channels: int, out_channels: int) -> nn.Module:
        """Build one strided 1D conv stage (causal or symmetric-padded)."""
        if self.is_causal:
            return CausalConv1D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self._kernel_size,
                stride=self._stride,
                padding=None,
            )
        return torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self._kernel_size,
            stride=self._stride,
            padding=self._left_padding,
        )

    def get_sampling_frames(self) -> list[int]:
        return [1, self.subsampling_factor]

    def get_streaming_cache_size(self) -> list[int]:
        return [0, self.subsampling_factor + 1]

    def forward(
        self, x: Tensor, mask: Optional[Tensor]
    ) -> tuple[Tensor, Optional[Tensor]]:
        """
        Forward method for NeMo subsampling.

        Args:
            x: input tensor, (B, T, feat_in)
            mask: optional input mask, (B, T); its row sums are used as the
                valid input lengths.

        Returns:
            x: Resulting tensor from subsampling (B, T //
                time_reduction_factor, feat_out)
            pad_mask: boolean mask of valid output frames (B, 1, T //
                time_reduction_factor), or None when ``mask`` is None.
        """
        # 2D variants see the input as a single-channel image; 1D variants
        # need features (channels) before time.
        x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)

        # split inputs if chunking_factor is set
        if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling:
            if self.subsampling_conv_chunking_factor == 1:
                # Auto mode: split only if needed, avoiding a bug / feature
                # limiting indexing of tensors to 2**31.
                # see https://github.com/pytorch/pytorch/issues/80020
                x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
                need_to_split = torch.numel(x) > x_ceil
            else:
                # if subsampling_conv_chunking_factor > 1 we always split
                need_to_split = True

            if need_to_split:
                x, success = self.conv_split_by_batch(x)
                if not success:  # if unable to split by batch, try by channel
                    if self._subsampling == "dw_striding":
                        x = self.conv_split_by_channel(x)
                    else:
                        x = self.conv(x)  # try anyway
            else:
                x = self.conv(x)
        else:
            x = self.conv(x)

        if self.conv2d_subsampling:
            # Flatten channel and feature axes, then project to feat_out.
            b, _, t, _ = x.size()
            x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        else:
            # Back to channel-last layout.
            x = x.transpose(1, 2)

        if mask is None:
            return x, None

        max_audio_length = x.shape[1]
        feature_lens = mask.sum(1)
        padding_length = torch.ceil(feature_lens / self.subsampling_factor)
        if self.is_causal and self.subsampling_causal_cond:
            feature_lens_remainder = feature_lens % self.subsampling_factor
            padding_length[feature_lens_remainder != 1] += 1
        pad_mask = torch.arange(0, max_audio_length, device=x.device).expand(
            padding_length.size(0), -1
        ) < padding_length.unsqueeze(1)
        return x, pad_mask.unsqueeze(1)

    def reset_parameters(self) -> None:
        """Re-initialize weights uniformly (``dw_striding`` only)."""
        if self._subsampling == "dw_striding":
            with torch.no_grad():
                # init conv
                scale = 1.0 / self._kernel_size
                dw_max = (self._kernel_size**2) ** -0.5
                pw_max = self._conv_channels**-0.5

                torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
                torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)

                # Layout after the first conv+activation is repeated
                # (depthwise, pointwise, activation) triples.
                for idx in range(2, len(self.conv), 3):
                    torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)

                # init fc (80 * 64 = 5120 from https://github.com/kssteven418/
                # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/
                # src/models/conformer_encoder.py#L487
                fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
                torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
                torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)

    def conv_split_by_batch(self, x: Tensor) -> tuple[Tensor, bool]:
        """Tries to split input by batch, run conv and concat results.

        Returns ``(output, True)`` on success, or ``(x, False)`` unchanged
        when the batch cannot be split.
        """
        b, _, _, _ = x.size()
        if b == 1:  # can't split if batch size is 1
            return x, False

        if self.subsampling_conv_chunking_factor > 1:
            cf = self.subsampling_conv_chunking_factor
        else:
            # avoiding a bug / feature limiting indexing of tensors to 2**31
            # see https://github.com/pytorch/pytorch/issues/80020
            x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
            p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
            cf = 2**p

        new_batch_size = b // cf
        if new_batch_size == 0:  # input is too big
            return x, False

        return (
            torch.cat(
                [self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]
            ),
            True,
        )

    def conv_split_by_channel(self, x: Tensor) -> Tensor:
        """For ``dw_striding``: run the conv stack with depthwise convs split
        by channel and pointwise convs split by time, concatenating results."""
        x = self.conv[0](x)  # full conv2D
        x = self.conv[1](x)  # activation

        for i in range(self._sampling_num - 1):
            _, c, t, _ = x.size()

            if self.subsampling_conv_chunking_factor > 1:
                cf = self.subsampling_conv_chunking_factor
            else:
                # avoiding a bug / feature limiting indexing of tensors
                # to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
                # Clamp so small tensors get one chunk instead of a
                # fractional factor (which also produced a single chunk).
                cf = 2 ** max(p, 0)

            new_c = max(int(c // cf), 1)
            new_t = max(int(t // cf), 1)

            x = self.channel_chunked_conv(
                self.conv[i * 3 + 2], new_c, x
            )  # conv2D, depthwise

            # splitting pointwise convs by time
            x = torch.cat(
                [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)],
                2,
            )  # conv2D, pointwise
            x = self.conv[i * 3 + 4](x)  # activation
        return x

    def channel_chunked_conv(
        self, conv: torch.nn.Module, chunk_size: int, x: Tensor
    ) -> Tensor:
        """Apply depthwise ``conv`` to ``x`` in channel chunks of ``chunk_size``.

        Each chunk uses the matching slice of ``conv``'s weight and bias, and
        should reproduce what ``conv(x)`` yields on those channels.
        """
        ind = 0
        out_chunks = []
        for chunk in torch.split(x, chunk_size, 1):
            step = chunk.size(1)

            if self.is_causal:
                # Mirror CausalConv2D.forward, which pads only the last
                # dimension. Padding the second-to-last dimension too made
                # the chunked output shape differ from the unchunked one.
                chunk = F.pad(
                    chunk,
                    pad=(self._kernel_size - 1, self._stride - 1, 0, 0),
                )
                padding = 0
            else:
                padding = self._left_padding

            ch_out = F.conv2d(
                chunk,
                conv.weight[ind : ind + step, :, :, :],
                bias=conv.bias[ind : ind + step],
                stride=self._stride,
                padding=padding,
                groups=step,
            )
            out_chunks.append(ch_out)
            ind += step

        return torch.cat(out_chunks, 1)

    def change_subsampling_conv_chunking_factor(
        self, subsampling_conv_chunking_factor: int
    ) -> None:
        """Validate and set the chunking factor (-1, 1, or an even number)."""
        if (
            subsampling_conv_chunking_factor != -1
            and subsampling_conv_chunking_factor != 1
            and subsampling_conv_chunking_factor % 2 != 0
        ):
            raise ValueError(
                "subsampling_conv_chunking_factor should be -1, 1, or a power of 2"
            )
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor


1585
1586
1587
1588
1589
1590
1591
1592
def calc_length(
    lengths: Tensor,
    all_paddings: int,
    kernel_size: int,
    stride: int,
    ceil_mode: bool,
    repeat_num: int = 1,
) -> Tensor:
    """Output length after ``repeat_num`` identical convolution/pooling layers.

    Each layer maps ``L`` to ``round((L + all_paddings - kernel_size) / stride
    + 1)``, where ``round`` is ceil when ``ceil_mode`` is set and floor
    otherwise. The result is returned as an int32 tensor.
    """
    offset = all_paddings - kernel_size
    rounding = torch.ceil if ceil_mode else torch.floor
    out = lengths
    for _ in range(repeat_num):
        out = rounding(torch.div(out.to(dtype=torch.float) + offset, stride) + 1.0)
    return out.to(dtype=torch.int)


####  multihead attention starts here
class AttModule(nn.Module):
    """Base class for attention modules, with an export-mode switch.

    The default ``forward`` returns its four inputs unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        # Toggled via set_export(); not read anywhere in this base class.
        self.export_mode = False

    def set_export(self, mode: bool = True) -> None:
        """Enable (default) or disable export mode."""
        self.export_mode = mode

    def forward(
        self,
        x: Tensor,
        memory: Optional[Tensor] = None,
        pos_emb: Optional[Tensor] = None,
        att_mask: Optional[Tensor] = None,
    ) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Identity pass-through of the attention inputs.

        Args:
            x: input tensor.
            memory: memory tensor.
            pos_emb: positional encoder embedding.
            att_mask: attention mask tensor.

        Returns:
            ``(x, memory, pos_emb, att_mask)`` unchanged.
        """
        passthrough = (x, memory, pos_emb, att_mask)
        return passthrough


1633
class AttBlock(BlockBase, AttModule):
    """Attention block combining BlockBase sizing with the AttModule API."""

    def memory_dims(self, max_len: bool = False) -> tuple[int, int]:
        """Return the memory dimensions as ``(1, input_size)``.

        Args:
            max_len: accepted for interface compatibility; not used.
        """
        return 1, self.input_size


def masked_softmax(
    scores: Tensor,
    mask: Optional[Tensor],
) -> Tensor:
    """Softmax over the last axis of ``scores``, honouring an optional mask.

    Args:
        scores: attention scores, (batch, head, time1, time2).
        mask: optional mask, (batch, time1, time2); positions equal to 0 are
            excluded. It is broadcast across heads.

    Returns:
        Attention weights of the same shape as ``scores``. Masked positions
        are exactly 0, including rows where every position is masked (those
        would otherwise be NaN).
    """
    if mask is None:
        return torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
    blocked = mask.unsqueeze(1) == 0  # (batch, 1, time1, time2)
    probs = torch.softmax(scores.masked_fill(blocked, -torch.inf), dim=-1)
    return probs.masked_fill(blocked, 0.0)  # (batch, head, time1, time2)


class MultiHeadedAttention(nn.Module):
1657
    """Multi-Head Attention layer with optional relative position embedding
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
    and GLU.

    Args:
        n_head: int
            the number of heads.
        n_feat: int
            input size features.
        dropout_rate: float
            dropout rate.
        attention_inner_dim: int, optional
            the attention dimension used in the class,
            it can be different from the input dimension n_feat.
            default: -1 (equal to n_feat).
        use_pt_scaled_dot_product_attention: bool, optional
            if set True, use pytorch scaled dot product attention in training.
1673
1674
            NOTE: this will NOT be used in ONNX decoding due to a lack of
            support.  In that case, we use the original attention
1675
1676
1677
            implementation, which shows no regression.
            default: False.
        n_value: int, optional
1678
            if set to values other than -1, use a different dimension for
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
            value. With the default value (i.e. -1), it is backward compatible.
        group_size: int, optional. must divide `n_head`
            if group_size > 1:       GQA
            if group_size = 1:       MHA
            if group_size = n_head:  MQA
    """

    inv_sqrt_d_k: torch.jit.Final[float]
    h: torch.jit.Final[int]
    h_k: torch.jit.Final[int]
    g: torch.jit.Final[int]

    def __init__(
        self,
1693
1694
1695
1696
1697
1698
1699
1700
        n_head: int,
        n_feat: int,
        dropout_rate: float,
        attention_inner_dim: int = -1,
        glu_type: str = "swish",
        bias_in_glu: bool = True,
        use_pt_scaled_dot_product_attention: bool = False,
        n_value: int = -1,
1701
        group_size: int = 1,
1702
    ) -> None:
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
        super().__init__()
        if n_value == -1:
            n_value = n_feat
        if attention_inner_dim == -1:
            attention_inner_dim = n_feat
        assert attention_inner_dim % n_head == 0

        # We assume d_v always equals d_k
        self.d_k = attention_inner_dim // n_head
        self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
        self.h = n_head
        assert n_head % group_size == 0, "group_size must divide n_head"
        self.g = group_size
        self.h_k = n_head // group_size

        self.linear_q = nn.Linear(n_feat, attention_inner_dim)
        self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
        self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
        self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)

        self.attn = torch.jit.Attribute(None, Optional[Tensor])
        self.dropout = nn.Dropout(p=dropout_rate)
        self.dropout_rate = dropout_rate
1726
        self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743

        if use_pt_scaled_dot_product_attention and group_size > 1:
            raise ValueError("Cannot use PT Scaled Attention with GQA")

        # Torchscript eager quantization.  Note that these functions below are
        # NOOPs and have very little impact on performance unless quantization
        # is enabled.
        self.quant_q = torch.ao.quantization.QuantStub()
        self.quant_x = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.ffunc = torch.ao.nn.quantized.FloatFunctional()

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
1744
1745
        pos_k: Optional[Tensor],
        pos_v: Optional[Tensor],
1746
1747
        mask: Optional[Tensor],
        relative_attention_bias: Optional[Tensor] = None,
1748
    ) -> Tensor:
1749
1750
1751
        """Compute 'Scaled Dot Product Attention'.

        Args:
1752
1753
1754
1755
1756
1757
            query: query tensor (batch, time1, size)
            key: key tensor (batch, time2, size)
            value: value tensor (batch, time1, size)
            pos_k: key tensor used for relative positional embedding.
            pos_v: value tensor used for relative positional embedding.
            mask: mask tensor (batch, time1, time2)
1758
            relative_attention_bias: bias added to attention logits w.r.t.
1759
                relative positions
1760
1761
1762
1763
                (1, n_head, time1, time2)
        """
        n_batch = query.size(0)

1764
1765
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)  # (b, t, d)
        k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k)  # (b, t, d)
1766
        v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
1767
1768
1769
1770
1771
        q = (
            q.transpose(1, 2)
            if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting()
            else q.transpose(1, 2) * self.inv_sqrt_d_k
        )
1772
1773
1774
        k = k.transpose(1, 2)  # (batch, head_k, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head_k, time2, d_k)

1775
        if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting():
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
            attn_mask = None
            if mask is not None:
                mask = mask.unsqueeze(1)
                if relative_attention_bias is not None:
                    attn_mask = mask + relative_attention_bias
                else:
                    attn_mask = mask
                if mask.dtype != q.dtype:
                    attn_mask = attn_mask.to(q.dtype)

1786
1787
            with torch.nn.attention.sdpa_kernel(
                [
cyyever's avatar
cyyever committed
1788
1789
1790
1791
                    torch.nn.attention.SDPBackend.FLASH_ATTENTION,
                    torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
                    torch.nn.attention.SDPBackend.MATH,
                    torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
1792
1793
                ]
            ):
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
                x = torch.nn.functional.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=attn_mask,
                    dropout_p=self.dropout_rate,
                )
        else:
            if self.h != self.h_k:
                q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
                A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
            else:
                A = torch.matmul(q, k.transpose(-2, -1))
            if pos_k is not None:
                if self.h != self.h_k:
                    B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
                else:
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
                    reshape_q = (
                        q.contiguous()
                        .view(n_batch * self.h, -1, self.d_k)
                        .transpose(0, 1)
                    )  # (t1,nh,dk)
                    B = torch.matmul(
                        reshape_q, pos_k.transpose(-2, -1)
                    )  # pos_k: (t1,dk,t2)
                    B = B.transpose(0, 1).view(
                        n_batch, self.h, pos_k.size(0), pos_k.size(1)
                    )
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
                scores = A + B
            else:
                scores = A

            if relative_attention_bias is not None:
                scores = scores + relative_attention_bias

            attn = masked_softmax(scores, mask)  # (batch, head, time1, time2)

            self.attn = attn

            p_attn = self.dropout(attn)
1834
            x = torch.matmul(p_attn.to(v.dtype), v)  # (batch, head, time1, d_k)
1835
            if pos_v is not None:
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
                reshape_attn = (
                    p_attn.contiguous()
                    .view(n_batch * self.h, pos_v.size(0), pos_v.size(1))
                    .transpose(0, 1)
                )  # (t1, bh, t2)

                attn_v = (
                    torch.matmul(reshape_attn, pos_v)
                    .transpose(0, 1)
                    .contiguous()
                    .view(n_batch, self.h, pos_v.size(0), self.d_k)
                )
1848
                x = x + attn_v
1849
1850
1851
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k)
        )  # (batch, time1, d_model)
1852
1853
1854
1855
1856
1857
1858
1859

        return self.linear_out(x)  # (batch, time1, d_model)


class MultiSequential(torch.nn.Sequential):
    """A ``torch.nn.Sequential`` whose stages take and return multiple values.

    Each child receives the previous child's output unpacked as positional
    arguments. The last child's output is returned as-is.
    """

    @torch.jit.ignore
    def forward(self, *args) -> tuple:
        """Thread ``args`` through every child module in order."""
        outputs = args
        for layer in self:
            outputs = layer(*outputs)
        return outputs


1867
def get_offset(input_layer: str, time_reduction: int) -> int:
    """Return the frame offset used to compute #frames of a subsampled feature.

    Args:
        input_layer: Type of an input layer (e.g. "conv2d", "nemo_conv").
        time_reduction: time reduction factor for downsampling a feature

    Returns:
        int: offset for the known (layer, reduction) pairs, otherwise 0.
    """
    known_offsets = {
        ("conv2d", 4): 3,
        ("nemo_conv", 4): 3,
        ("conv2d", 6): 1,
        ("conv2d", 8): 7,
        ("nemo_conv", 8): 7,
    }
    return known_offsets.get((input_layer, time_reduction), 0)


1886
def unfold_tensor(xs_pad: Tensor, max_seq_len: int) -> Tensor:
    """Split the time axis of an (N, T, D) tensor into fixed-length chunks.

    The result has shape (N * T', max_seq_len, D) with T' = T // max_seq_len.
    Chunks are taken with stride ``max_seq_len``, so trailing frames that do
    not fill a whole chunk are dropped.

    Args:
        xs_pad: input tensor with shape (N, T, D)
        max_seq_len: maximum sequence length (chunk size)
    """
    _, _, feat_dim = xs_pad.shape
    # (N, T, D) -> (N, D, 1, T): treat time as the width of a 1-row image.
    as_image = xs_pad.transpose(-1, -2).unsqueeze(-2)
    # -> (N, D * max_seq_len, T')
    patches = F.unfold(
        as_image,
        kernel_size=(1, max_seq_len),
        stride=(1, max_seq_len),
    )
    batch, _, num_chunks = patches.shape
    # (N, D, max_seq_len, T') -> (N, T', max_seq_len, D)
    chunks = patches.view(batch, -1, max_seq_len, num_chunks).permute(0, 3, 2, 1)
    # -> (N * T', max_seq_len, D)
    return chunks.contiguous().view(-1, max_seq_len, feat_dim)