phi4mm_utils.py 64.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import math

import torch
import torch.nn.functional as F
from torch import Tensor, nn


15
class BlockBase(nn.Module):
    """Abstract base for blocks that record their input and output widths.

    Args:
        input_size: size of the block input.
        output_size: size of the block output.
    """

    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        self.input_size, self.output_size = input_size, output_size


24
def get_activation(name: str = "relu") -> torch.nn.Module:
    """Select an activation function by name.

    Args:
        name: str
            activation function name (case-insensitive), one of
            ["relu", "gelu", "swish", "sigmoid", "identity"],
            default "relu". "swish" maps to ``nn.SiLU``.

    Returns:
        A freshly constructed activation module. Note that "relu" yields an
        in-place ``nn.ReLU(inplace=True)``.

    Raises:
        NotImplementedError: if the (lower-cased) name is not supported; the
            exception message is the lower-cased name.
    """
    key = name.lower()
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "gelu": nn.GELU,
        "swish": nn.SiLU,
        "sigmoid": nn.Sigmoid,
        "identity": nn.Identity,
    }
    factory = factories.get(key)
    if factory is None:
        raise NotImplementedError(key)
    return factory()
46
47


48
49
50
def adaptive_enc_mask(
    x_len: int, chunk_start_idx: list[int], left_window: int = 0, right_window: int = 0
) -> torch.Tensor:
    """Build the chunk-wise attention mask for streaming Transformer
    Transducer encoders.

    Frame ``i`` belongs to the chunk ``c`` whose span
    ``[chunk_start_idx[c], next_chunk_start)`` contains it. It may attend to
    every frame from the start of chunk ``c - left_window`` up to (excluding)
    the end of chunk ``c + right_window``; window indices are clamped to the
    valid chunk range.

    Args:
        x_len: sequence length.
        chunk_start_idx: first idx of each chunk, such as [0, 18, 36, 48].
            It also supports adaptive chunk sizes such as [0, 10, 15, 45].
            Expected to be sorted ascending.
        left_window: how many left chunks can be seen.
        right_window: how many right chunks can be seen. It is used for
            chunk overlap models.

    Returns:
        A boolean tensor of shape ``(x_len, x_len)``; ``mask[i, j]`` is True
        when frame ``i`` may attend to frame ``j``. For example
        ``adaptive_enc_mask(4, [0, 2])`` gives::

            tensor([[ True,  True, False, False],
                    [ True,  True, False, False],
                    [False, False,  True,  True],
                    [False, False,  True,  True]])
    """
    # Build the indices directly as int64: routing through torch.Tensor
    # (float32) silently rounds indices above 2**24.
    chunk_start_idx = torch.as_tensor(chunk_start_idx, dtype=torch.long)
    num_chunks = len(chunk_start_idx)
    # Chunk starts with a leading 0, e.g. [0, 0, 18, 36, 48].
    start_pad = torch.nn.functional.pad(chunk_start_idx, (1, 0))
    # Chunk ends (exclusive) with x_len appended, e.g. [0, 18, 36, 48, x_len].
    end_pad = torch.nn.functional.pad(chunk_start_idx, (0, 1), value=x_len)

    frames = torch.arange(0, x_len)
    seq_range = frames.unsqueeze(-1)  # [x_len, 1]
    # Index (into start_pad / end_pad) of the chunk each frame belongs to.
    idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
    seq_range_expand = frames.unsqueeze(0).expand(x_len, -1)  # [x_len, x_len]

    # Leftmost visible chunk, clamped at the first padded slot.
    idx_left = (idx - left_window).clamp(min=0)
    boundary_left = start_pad[idx_left]
    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)

    # Rightmost visible chunk, clamped at the last padded slot.
    idx_right = (idx + right_window).clamp(max=num_chunks)
    boundary_right = end_pad[idx_right]
    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
    return mask_left & mask_right


class GLU(nn.Module):
    """Implement Gated Linear Unit (GLU) module.

    Splits the input into two chunks along ``dim`` and returns
    ``first_half * act_fn(second_half)``.

    Args:
        dim: int
            dimension along which the input is split, default -1.
        act_name: str
            name of the gating activation, passed to ``get_activation``,
            default "sigmoid".
    """

    def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
        super().__init__()
        self.dim = dim
        self.act_fn = get_activation(act_name)

    def forward(self, x: Tensor) -> Tensor:
        """GLU forward.

        Applies the gating activation to the second half of the input and
        multiplies the result with the first half.

        Args:
            x: torch.Tensor
                Input; its size along ``self.dim`` is expected to be even.

        Returns:
            ``first_half * act_fn(second_half)``, half the size of ``x``
            along ``self.dim``.
        """
        half_x, gate = x.chunk(2, dim=self.dim)
        return half_x * self.act_fn(gate)


# TODO: Abdel, this can be improved using GLU module
class GLUPointWiseConv(nn.Module):
    """GLUPointWiseConv module
    used for conformer architecture,
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    A Conv1d projects the input to ``2 * output_dim`` channels; the first
    half is multiplied by the (activated) second half.

    Args:
        input_dim: int
            input channel size.
        output_dim: int
            output channel size.
        kernel_size: int
            kernel size
        glu_type: str, optional
            gating activation: any name accepted by ``get_activation``
            (e.g. "sigmoid", "relu", "gelu", "swish"), or "bilinear" for a
            plain product of the two halves without activation.
            default "sigmoid".
        bias_in_glu: bool, optional
            use additive bias in glu (one learnable bias per half).
        causal: bool, optional
            if set to True, ``kernel_size - 1`` frames of padding are added
            on both sides, and the caller trims the trailing
            ``kernel_size - 1`` output frames so the convolution can't see
            future frames. Otherwise padding is ``(kernel_size - 1) // 2``.
            default False.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        kernel_size: int,
        glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
        causal: bool = False,
    ) -> None:
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim
        self.bias_in_glu = bias_in_glu
        padding = kernel_size - 1 if causal else (kernel_size - 1) // 2
        self.ext_pw_conv_1d = nn.Conv1d(
            input_dim,
            output_dim * 2,
            kernel_size,
            1,
            padding=padding,
        )

        # "bilinear" has no activation; get_activation() raises for it, which
        # used to make the bilinear path in forward() unreachable.
        if glu_type == "bilinear":
            self.glu_act = nn.Identity()
        else:
            self.glu_act = get_activation(glu_type)

        if bias_in_glu:
            self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
            self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: input tensor with channels in the last dimension.

        Returns:
            Tensor with ``output_dim`` channels in the last dimension.
        """
        # to be consistent with GLULinear, we assume the input always has the
        # #channel (#dim) in the last dimension of the tensor, so need to
        # switch the dimension first for 1D-Conv case
        x = x.permute([0, 2, 1])
        x = self.ext_pw_conv_1d(x)

        value = x[:, 0 : self.output_dim, :]
        gate = x[:, self.output_dim : self.output_dim * 2, :]
        if self.bias_in_glu:
            value = value + self.b1
            gate = gate + self.b2
        if self.glu_type != "bilinear":
            gate = self.glu_act(gate)
        x = value * gate

        x = x.permute([0, 2, 1])
        return x


class DepthWiseSeperableConv1d(nn.Module):
    """DepthWiseSeperableConv1d module used in Convnet module
    for the conformer, for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    A grouped (depthwise) Conv1d followed by an optional 1x1 (pointwise)
    Conv1d.

    Args:
        input_dim: int
            input channel size.
        depthwise_seperable_out_channel: int
            when non-zero, the output channel count of the second
            (pointwise) conv1d layer; when 0, that layer is skipped.
        kernel_size: int
            kernel size of the depthwise conv.
        depthwise_multiplier: int
            number of copies of each input channel; the depthwise conv
            produces ``input_dim * depthwise_multiplier`` channels.
        padding: int, optional
            padding for the depthwise conv1d, default: 0.
    """

    def __init__(
        self,
        input_dim: int,
        depthwise_seperable_out_channel: int,
        kernel_size: int,
        depthwise_multiplier: int,
        padding: int = 0,
    ) -> None:
        super().__init__()

        hidden_channels = input_dim * depthwise_multiplier
        # groups=input_dim: every input channel is convolved independently.
        self.dw_conv = nn.Conv1d(
            input_dim,
            hidden_channels,
            kernel_size,
            1,
            padding=padding,
            groups=input_dim,
        )

        self.pw_conv = (
            nn.Conv1d(hidden_channels, depthwise_seperable_out_channel, 1, 1, 0)
            if depthwise_seperable_out_channel != 0
            else nn.Identity()
        )
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

    def forward(self, x: Tensor) -> Tensor:
        """Run the depthwise conv, then the pointwise conv if configured.

        Args:
            x: input tensor in Conv1d layout (channels in dimension 1).
        """
        out = self.dw_conv(x)
        if self.depthwise_seperable_out_channel == 0:
            return out
        return self.pw_conv(out)


class ConvModule(nn.Module):
    """ConvModule Module for the conformer block.
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        ext_pw_out_channel: int
            if > 0, ext_pw_out_channel is a dim channel size
             for the last pointwise conv after the activation.
             If 0, the GLU and the last pointwise conv are replaced by
             learnable scale/bias parameters
             (``pw_conv_simplify_w`` / ``pw_conv_simplify_b``).
        depthwise_seperable_out_channel: int
            if set different to 0, the number of
             depthwise_seperable_out_channel
             will be used as a channel_out of the second conv1d layer.
             otherwise, it equal to 0, the second conv1d layer is skipped.
        ext_pw_kernel_size: int
            kernel size of the conv pointwise of the conformer.
        kernel_size: int
            kernel size.
        depthwise_multiplier: int
            number of input_dim channels duplication. this value
             will be used to compute the hidden channels of the Conv1D.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, convolution have no access
             to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation.
            default False
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
             by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
             by only the current chunk.
            NOTE: not read by this implementation.
        chunk_size: int, optional
            chunk size for cnn. default 18
            NOTE: not read by this implementation.
        activation: str, optional
            activation function used in ConvModule,
            default: "relu".
        glu_type: str, optional
            activation function used for the glu,
            default: "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
             before GLU.
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
             otherwise, used GLUPointWiseConv module.
              default to False.
        export: bool, optional,
            if set to True (together with causal), the depthwise conv
             padding is 0.  This is for inference, or onnx export.
             Typically this is set by the export program or the decoder
             program, and it isn't present in your config file.
             default False
    """

    def __init__(
        self,
        input_dim: int,
        ext_pw_out_channel: int,
        depthwise_seperable_out_channel: int,
        ext_pw_kernel_size: int,
        kernel_size: int,
        depthwise_multiplier: int,
        dropout_rate: float,
        causal: bool = False,
        batch_norm: bool = False,
        chunk_se: int = 0,
        chunk_size: int = 18,
        activation: str = "relu",
        glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
        linear_glu_in_convm: bool = False,
        export: bool = False,
    ) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.input_dim = input_dim
        self.ext_pw_out_channel = ext_pw_out_channel
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.glu_type = glu_type
        self.bias_in_glu = bias_in_glu
        self.linear_glu_in_convm = linear_glu_in_convm
        self.causal = causal

        self._add_ext_pw_layer()

        self.batch_norm = batch_norm
        self.kernel_size = kernel_size

        if batch_norm:
            self.bn_layer = nn.BatchNorm1d(input_dim)

        self.act = get_activation(activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.export = export

        if causal:
            padding = 0 if export else kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2

        self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
            input_dim,
            depthwise_seperable_out_channel,
            kernel_size,
            depthwise_multiplier,
            padding=padding,
        )

        # Project the depthwise(-separable) output back to input_dim when its
        # channel count differs; forward() checks hasattr(self, "ln2").
        if depthwise_seperable_out_channel != 0:
            if input_dim != depthwise_seperable_out_channel:
                self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim)
        else:
            if depthwise_multiplier != 1:
                self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim)

    def _add_ext_pw_layer(self) -> None:
        """
        This function is an extension of __init__ function
        and dedicated to the convolution module creation
        of the conformer.
        """
        self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = (
            nn.Identity()
        )  # jit hacks.
        self.squeeze_excitation = nn.Identity()  # jit.
        self.apply_ln1 = self.fix_len1 = False  # jit.

        if self.ext_pw_out_channel != 0:
            if self.causal:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1),
                )
                # Causal padding adds (kernel - 1) trailing frames that
                # forward() must drop again.
                if self.ext_pw_kernel_size > 1:
                    self.fix_len1 = True
                else:
                    self.fix_len1 = False
            else:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1) // 2,
                )
                self.fix_len1 = False

            if self.linear_glu_in_convm:
                self.glu = GLULinear(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.glu_type,
                    self.bias_in_glu,
                )
            else:
                self.glu = GLUPointWiseConv(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    self.glu_type,
                    self.bias_in_glu,
                    self.causal,
                )

            if self.input_dim != self.ext_pw_out_channel:
                self.apply_ln1 = True
                self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
            else:
                self.apply_ln1 = False
        else:
            self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
            self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x: Tensor) -> Tensor:
        """ConvModule Forward.

        Args:
            x: input tensor with channels in the last dimension
                (presumably (batch, time, input_dim) — the LayerNorm and GLU
                operate on the last dimension).

        Returns:
            Tensor in the same channels-last layout as the input.
        """
        x = self.layer_norm(x)

        if self.ext_pw_out_channel != 0:
            x = self.glu(x)
            # GLUPointWiseConv pads causal convolutions by (kernel - 1) on both
            # sides, so the trailing frames must be dropped. GLULinear adds no
            # padding; trimming its output would discard real frames.
            if (
                self.causal
                and self.ext_pw_kernel_size > 1
                and not self.linear_glu_in_convm
            ):
                x = x[:, : -(self.ext_pw_kernel_size - 1), :]
            if self.apply_ln1:
                x = self.ln1(x)
        else:
            x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
            x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
            x = x_0 + x_1

        # Switch to the channels-first layout expected by Conv1d.
        x = x.permute([0, 2, 1])

        x = self.dw_sep_conv_1d(x)
        if self.causal and self.kernel_size > 1:
            x = x[:, :, : -(self.kernel_size - 1)]
        if hasattr(self, "ln2"):
            x = x.permute([0, 2, 1])
            x = self.ln2(x)
            x = x.permute([0, 2, 1])
        if self.batch_norm:
            x = self.bn_layer(x)
        x = self.act(x)

        if self.ext_pw_out_channel != 0:
            x = self.ext_pw_conv_1d(x)
            if self.fix_len1:
                x = x[:, :, : -(self.ext_pw_kernel_size - 1)]

            if self.apply_ln1:
                x = x.permute([0, 2, 1])
                x = self.ln1(x)
                x = x.permute([0, 2, 1])

            x = x.permute([0, 2, 1])
        else:
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
            x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
            x = x.squeeze(1)

        x = self.dropout(x)
        return x


class GLULinear(nn.Module):
    """Linear + GLU module

    Projects the input to ``2 * output_dim`` features with a linear layer,
    then applies ``GLU`` along the last dimension, yielding ``output_dim``
    features.

    Args:
        input_dim: int
            input size
        output_dim: int
            output size.
        glu_type: str, optional
            activation applied to the gating half inside the GLU,
            default "sigmoid".
        bias_in_glu: bool, optional
            If True, the linear projection has an additive bias.
            Default True.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
        self.glu_act = GLU(-1, glu_type)

    def forward(self, x: Tensor) -> Tensor:
        """GLULinear forward

        Args:
            x: input tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor whose last dimension has size ``output_dim``.
        """
        x = self.linear(x)
        return self.glu_act(x)


class FeedForward(nn.Module):
    """FeedForward Module.
    For more details see Conformer paper:
        https://arxiv.org/pdf/2005.08100.pdf

    Computes ``Dropout(Linear(Dropout(GLULinear(LayerNorm(x)))))``; the
    output has the same feature size (``d_model``) as the input.

    Args:
        d_model: int
            input and output size.
        d_inner: int
            hidden (inner) size produced by the GLULinear layer.
        dropout_rate: float,
            dropout rate.
        activation: str,
            gating activation of the GLULinear layer, any name accepted by
            ``get_activation``, default "sigmoid".
        bias_in_glu: bool, optional
            whether the GLULinear projection uses a bias, default True.
    """

    def __init__(
        self,
        d_model: int,
        d_inner: int,
        dropout_rate: float,
        activation: str = "sigmoid",
        bias_in_glu: bool = True,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        module = GLULinear(d_model, d_inner, activation, bias_in_glu)
        self.net = nn.Sequential(
            module,
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x: Tensor) -> Tensor:
        """FeedForward forward function.

        Args:
            x: input tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor with the same last-dimension size (``d_model``).
        """
        out = self.net(self.layer_norm(x))

        return out


#### positional encoding starts here
def _pre_hook(
611
612
613
614
615
616
617
618
    state_dict: dict,
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: list[str],
    unexpected_keys: list[str],
    error_msgs: list[str],
) -> None:
619
620
621
622
    """Perform pre-hook in load_state_dict for backward compatibility.

    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
623
        Therefore, we remove the item "pe" from `state_dict` for backward
624
625
626
627
628
629
630
631
632
633
        compatibility.

    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)


class T5RelativeAttentionLogitBias(nn.Module):
    """
    This module implements the relative position bias described in Section
    2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf

    The Huggingface implementation is used as a reference
    https://github.com/huggingface/transformers/blob/v4.30.0/src/
    transformers/models/t5/modeling_t5.py#L435

    Modifies attention as Q*K^T + B, where B is a learned scalar bias based
    on relative position of the query and key. It is HxNxN, where H is the
    number of heads, N is the sequence length.

    Modifications to the original T5 bias:
    - Skipping of the bucketing step. Original T5 bias converted rel
      position distances into logarithmically increasing buckets. This is
      supposed to help with length generalization.
    - The (clipped) rel position index is used directly as the bias index,
      as length generalization isn't needed (40s max is good enough for the
      ASR encoder), and it keeps ONNX export simple.
    - Biases can be asymmetric; the default implementation treats L->R and
      R->L the same. Asymmetric was found to yield better results in the
      authors' experiments.

    Args:
        num_heads: int
            Number of attention heads
        num_buckets: int
            Number of buckets to use for relative attention bias. This is the
            size of the learnable bias parameter. Bucketing is not yet
            supported, so this defaults to -1 which means no bucketing is
            used (max_distance determines size of bias param).
        max_distance: int
            Maximum distance to use for relative attention bias. With
            num_buckets=-1, relative positions are clipped to
            [-max_distance, max_distance - 1] (and, when symmetric, their
            magnitude to max_distance - 1). When num_buckets > 0 is
            supported, this will control the maximum distance for
            logarithmic bucketing after which all positions are in the same
            bucket.
        symmetric: bool
            Whether to use symmetric or asymmetric biases. symmetric=False uses
            2x number of bias params to distinguish L->R from R->L. This was
            found to be better for the encoder.
    """

    def __init__(
        self,
        num_heads: int,
        num_buckets: int = -1,
        max_distance: int = 1000,
        symmetric: bool = False,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.symmetric = symmetric
        self._skip_bucketing = self.num_buckets < 0
        if self._skip_bucketing:
            self.num_buckets = max_distance
        else:
            raise NotImplementedError(
                "T5 attention bias with bucketed positions is not yet tested"
            )
        if not self.symmetric:
            self.num_buckets *= 2
        self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(self, x: Tensor) -> Tensor:
        """Return the relative position bias for a sequence shaped like ``x``.

        Args:
            x: tensor whose dimension 1 is the sequence length L; only its
                size and device are used.

        Returns:
            Bias tensor of shape [1, num_heads, L, L].
        """
        # instantiate bias compatible with shape of x
        maxpos = x.size(1)
        positions = torch.arange(maxpos, device=x.device, dtype=torch.long)
        context_position = positions[:, None]
        memory_position = positions[None, :]
        relative_position = memory_position - context_position
        # clipping to a maximum distance using ops that play well with ONNX
        # export
        relative_position = relative_position.masked_fill(
            relative_position < -self.max_distance, -self.max_distance
        )
        relative_position = relative_position.masked_fill(
            relative_position > self.max_distance - 1, self.max_distance - 1
        )

        # mapping from relative position to index in the bias parameter
        if self._skip_bucketing:
            bias_idx = relative_position
        else:
            bias_idx = self._bucket_relative_position(relative_position)
        if self.symmetric:
            bias_idx = bias_idx.abs()
            # The clip above admits -max_distance, whose magnitude equals the
            # table size (num_buckets) and would index past the embedding;
            # cap it at the last bucket.
            bias_idx = bias_idx.masked_fill(
                bias_idx > self.num_buckets - 1, self.num_buckets - 1
            )
        else:
            bias_idx = bias_idx + self.num_buckets // 2

        t5_rel_att_bias = self.bias_values(bias_idx)  # [L, L, H]
        t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0)  # [1, H, L, L]

        return t5_rel_att_bias

    def _bucket_relative_position(self, relative_position: Tensor) -> Tensor:
        # this is a placeholder (isn't tested, likely buggy) using HuggingFace
        # implem as a reference this also needs to be extended to support
        # asymmetric +/- ve positions. It is unreachable today because
        # __init__ rejects num_buckets >= 0.
        num_buckets = self.num_buckets
        relative_buckets = 0
        # `causal` is never set by __init__; treat it as bidirectional instead
        # of raising AttributeError.
        if not getattr(self, "causal", False):
            # Halve a local copy (as HuggingFace does); halving
            # self.num_buckets would shrink it on every call.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(
                torch.long
            ) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in
        # positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(self.max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets


class AbsolutePositionalEncoding(nn.Module):
    """Absolute Positional encoding module.
    This module implement Absolute sinusoidal positional encoding
    from: https://arxiv.org/pdf/1706.03762.pdf

    Args:
        d_model: int
            Input embedding size (odd sizes are supported).
        dropout_rate: float
            dropout rate
        max_len: int, optional
            Maximum input length sequence, Default 5000

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Plain attribute rather than a registered buffer; _pre_hook strips a
        # stale "pe" entry from older checkpoints on load.
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x: torch.Tensor) -> None:
        """Reset the positional encodings.

        Rebuilds the table when it is missing or shorter than ``x.size(1)``;
        otherwise only moves it to ``x``'s dtype/device if they differ.

        Args:
            x: input tensor; dimension 1 is the required length.
        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        length = x.size(1)
        pe = torch.zeros(length, self.d_model)
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd (cosine) column than there
        # are frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding.

        Args:
            x: Input tensor. shape is (batch, time, ...)

        Returns:
            Encoded tensor ``x * sqrt(d_model) + pe`` after dropout.
            Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


#### forward embedding layers starts here
class MeanVarianceNormLayer(nn.Module):
    """Mean/variance normalization layer.

    Subtracts a learnable per-feature mean and scales by a learnable
    per-feature inverse standard deviation (initialized to 0 and 1).
    Typically used as a very first layer in a model.

    Args:
        input_size: int
            layer input size.
    """

    def __init__(self, input_size: int) -> None:
        super().__init__()
        self.input_size = input_size
        mean_init = torch.zeros(input_size)
        invstd_init = torch.ones(input_size)
        self.global_mean = nn.Parameter(mean_init)
        self.global_invstd = nn.Parameter(invstd_init)

    def forward(self, input_: Tensor) -> Tensor:
        """Return ``(input_ - global_mean) * global_invstd``.

        Args:
            input_: input tensor whose last dimension has ``input_size``
                features.
        """
        centered = input_ - self.global_mean
        return centered * self.global_invstd


class CausalConv1D(nn.Conv1d):
    """
    A causal version of nn.Conv1d where each step would have limited access to
    locations on its right or left
    All arguments are the same as nn.Conv1d except padding.

866
    If padding is set None, then paddings are set automatically to make it a
867
868
    causal convolution where each location would not see any steps on its right.

869
    If padding is set as a list (size of 2), then padding[0] would be used as
870
871
872
    left padding and padding[1] as right padding.
    It would make it possible to control the number of steps to be accessible
    on the right and left.
873
    This mode is not supported when stride > 1. padding[0]+padding[1] should
874
875
876
877
878
879
880
881
882
    be equal to (kernel_size - 1).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
883
        padding: str | int = 0,
884
885
886
887
888
889
890
891
892
893
894
895
896
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        self.cache_drop_size = None
        if padding is None:
            self._left_padding = kernel_size - 1
            self._right_padding = stride - 1
        else:
            if stride != 1 and padding != kernel_size - 1:
897
                raise ValueError("No striding allowed for non-symmetric convolutions!")
898
899
900
            if isinstance(padding, int):
                self._left_padding = padding
                self._right_padding = padding
901
902
903
904
905
            elif (
                isinstance(padding, list)
                and len(padding) == 2
                and padding[0] + padding[1] == kernel_size - 1
            ):
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
                self._left_padding = padding[0]
                self._right_padding = padding[1]
            else:
                raise ValueError(f"Invalid padding param: {padding}!")

        self._max_cache_len = self._left_padding

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

927
    def update_cache(
928
929
        self, x: Tensor, cache: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
930
931
932
933
934
935
936
        if cache is None:
            new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
            next_cache = cache
        else:
            new_x = F.pad(x, pad=(0, self._right_padding))
            new_x = torch.cat([cache, new_x], dim=-1)
            if self.cache_drop_size > 0:
937
                next_cache = new_x[:, :, : -self.cache_drop_size]
938
939
            else:
                next_cache = new_x
940
            next_cache = next_cache[:, :, -cache.size(-1) :]
941
942
        return new_x, next_cache

943
    def forward(
944
945
        self, x: Tensor, cache: Tensor | None = None
    ) -> Tensor | tuple[Tensor, Tensor | None]:
946
947
948
949
950
951
952
953
954
955
956
957
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        if cache is None:
            return x
        else:
            return x, cache


class CausalConv2D(nn.Conv2d):
    """
    A causal version of nn.Conv2d where each location in the 2D matrix would
    have no access to locations on its right or down
958
    All arguments are the same as nn.Conv2d except padding which should be
959
960
961
962
963
964
965
966
967
    set as None
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
968
        padding: str | int = 0,
969
970
971
972
973
974
975
976
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        if padding is not None:
977
            raise ValueError("Argument padding should be set to None for CausalConv2D.")
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
        self._left_padding = kernel_size - 1
        self._right_padding = stride - 1

        padding = 0
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    def forward(
        self,
998
999
        x: Tensor,
    ) -> Tensor:
1000
1001
1002
1003
        x = F.pad(
            x,
            pad=(self._left_padding, self._right_padding, 0, 0),
        )
1004
1005
1006
1007
1008
1009
1010
1011
1012
        x = super().forward(x)
        return x


class NemoConvSubsampling(torch.nn.Module):
    """Convlutional subsampling module, taken from NeMo ASR
    (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a
    34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)

1013
1014
    Striding Subsampling: "Speech-Transformer: A No-Recurrence
    Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong
1015
1016
1017
    et al. (https://ieeexplore.ieee.org/document/8462506)


1018
    Compared with the EncoderConv2D (`input_layer: custom`), this is a
1019
1020
1021
1022
    much simplified approach, and uses no LayerNorm and far fewer Conv2Ds.
    Moreover, depthwise convolutions are used to reduce FLOPs, but the first
      layer is kept as a regular convolution so as not to degrade accuracy.

1023
    `Striding` and `dw_striding` are the same except that the latter uses
1024
1025
1026
1027
1028
1029
1030
    depthwise convolutions after the first layer, whereas the former does not.

    Args:
        subsampling_factor (int): Time reduction factor
        feat_in (int): size of the input features
        feat_out (int): size of the output features
        subsampling (str): The subsampling technique, choose from
1031
            {"striding", "dw-striding", "striding_conv1d",
1032
            "dw_striding_conv1d"}
1033
        conv_channels (int): Number of channels for the convolution layers,
1034
                            default is 256.
1035
        subsampling_conv_chunking_factor (int): Input chunking factor which
1036
1037
1038
1039
1040
1041
1042
            can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1
        activation (Module): activation function, default is nn.ReLU()
        is_causal (bool): whether to use causal Conv1/2D, where each step will
            have limited access to locations on its right or left
    """

    def __init__(
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
        self,
        feat_in: int,
        feat_out: int,
        subsampling_factor: int = 4,
        subsampling: str = "dw_striding",
        conv_channels: int = 256,
        subsampling_conv_chunking_factor: int = 1,
        activation: torch.nn.Module = nn.ReLU(),  # noqa: B008
        is_causal: bool = False,
    ) -> None:
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
        super().__init__()
        self._subsampling = subsampling
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out

        if subsampling_factor % 2 != 0:
            raise ValueError("Sampling factor should be a multiply of 2!")
        self._sampling_num = int(math.log(subsampling_factor, 2))
        self.subsampling_factor = subsampling_factor
        self.is_causal = is_causal
        self.subsampling_causal_cond = subsampling in (
            "dw_striding",
            "striding",
            "striding_conv1d",
        )

1070
1071
1072
1073
1074
        if (
            subsampling_conv_chunking_factor != -1
            and subsampling_conv_chunking_factor != 1
            and subsampling_conv_chunking_factor % 2 != 0
        ):
1075
            raise ValueError(
1076
                "subsampling_conv_chunking_factor should be -1, 1, or a power of 2"
1077
            )
1078
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105

        in_channels = 1
        layers = []

        if subsampling == "dw_striding":
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            # Layer 1
            if self.is_causal:
                layers.append(
                    CausalConv2D(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=None,
1106
1107
                    )
                )
1108
1109
1110
1111
1112
1113
1114
1115
            else:
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
1116
1117
                    )
                )
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
            in_channels = conv_channels
            layers.append(activation)

            for i in range(self._sampling_num - 1):
                if self.is_causal:
                    layers.append(
                        CausalConv2D(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                            groups=in_channels,
1131
1132
                        )
                    )
1133
1134
1135
1136
1137
1138
1139
1140
1141
                else:
                    layers.append(
                        torch.nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
1142
1143
                        )
                    )
1144
1145
1146
1147
1148
1149
1150
1151
1152

                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
1153
1154
                    )
                )
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding":
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            for i in range(self._sampling_num):
                if self.is_causal:
                    layers.append(
                        CausalConv2D(
                            in_channels=in_channels,
                            out_channels=conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
1181
1182
                        )
                    )
1183
1184
1185
1186
1187
1188
1189
1190
                else:
                    layers.append(
                        torch.nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
1191
1192
                        )
                    )
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding_conv1d":
            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False

            if self.is_causal:
                self._left_padding = self._kernel_size - 1
                self._right_padding = self._stride - 1
                self._max_cache_len = subsampling_factor + 1
            else:
                self._left_padding = (self._kernel_size - 1) // 2
                self._right_padding = (self._kernel_size - 1) // 2
                self._max_cache_len = 0

            for i in range(self._sampling_num):
                if self.is_causal:
                    layers.append(
                        CausalConv1D(
                            in_channels=in_channels,
1217
1218
1219
1220
1221
                            out_channels=(
                                feat_out
                                if self._sampling_num == i + 1
                                else conv_channels
                            ),
1222
1223
1224
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
1225
1226
                        )
                    )
1227
1228
1229
1230
                else:
                    layers.append(
                        torch.nn.Conv1d(
                            in_channels=in_channels,
1231
1232
1233
1234
1235
                            out_channels=(
                                feat_out
                                if self._sampling_num == i + 1
                                else conv_channels
                            ),
1236
1237
1238
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
1239
1240
                        )
                    )
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "dw_striding_conv1d":
            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False

            self._left_padding = (self._kernel_size - 1) // 2
            self._right_padding = (self._kernel_size - 1) // 2

            # Layer 1
1255
1256
            layers.extend(
                [
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                        groups=in_channels,
                    ),
                    torch.nn.Conv1d(
                        in_channels=in_channels,
1267
1268
1269
                        out_channels=(
                            feat_out if self._sampling_num == 1 else conv_channels
                        ),
1270
1271
1272
1273
1274
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    ),
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
                ]
            )
            in_channels = conv_channels
            layers.append(activation)

            for i in range(self._sampling_num - 1):
                layers.extend(
                    [
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        ),
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=(
                                feat_out
                                if self._sampling_num == i + 2
                                else conv_channels
                            ),
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            groups=1,
                        ),
                    ]
                )
1305
1306
1307
1308
1309
1310
1311
                layers.append(activation)
                in_channels = conv_channels

        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        if subsampling in ["dw_striding", "striding"]:
1312
1313
            out_length = calc_length_int(
                lengths=feat_in,
1314
1315
1316
1317
1318
1319
                all_paddings=self._left_padding + self._right_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                ceil_mode=self._ceil_mode,
                repeat_num=self._sampling_num,
            )
1320
            self.out = torch.nn.Linear(conv_channels * out_length, feat_out)
1321
1322
1323
1324
1325
1326
1327
1328
1329
            self.conv2d_subsampling = True
        elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
            self.out = None
            self.conv2d_subsampling = False
        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        self.conv = torch.nn.Sequential(*layers)

1330
    def get_sampling_frames(self) -> list[int]:
1331
1332
        return [1, self.subsampling_factor]

1333
    def get_streaming_cache_size(self) -> list[int]:
1334
1335
        return [0, self.subsampling_factor + 1]

1336
    def forward(self, x: Tensor, mask: Tensor | None) -> tuple[Tensor, Tensor | None]:
1337
1338
1339
1340
        """
        Forward method for NeMo subsampling.

        Args:
1341
1342
            x: input tensor
            mask: input mask
1343
1344

        Returns:
1345
            x: Resulting tensor from subsampling (B, T //
1346
                time_reduction_factor, feat_out)
1347
            pad_mask: tensor of padded hidden state sequences (B, 1, T //
1348
1349
1350
1351
1352
                time_reduction_factor)
        """
        x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)

        # split inputs if chunking_factor is set
1353
        if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling:
1354
1355
1356
1357
1358
1359
            if self.subsampling_conv_chunking_factor == 1:
                # if subsampling_conv_chunking_factor is 1, we split only
                # if needed.
                # avoiding a bug / feature limiting indexing of tensors
                # to 2**31.
                # see https://github.com/pytorch/pytorch/issues/80020
1360
                x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
                need_to_split = torch.numel(x) > x_ceil
            else:
                # if subsampling_conv_chunking_factor > 1 we always split
                need_to_split = True

            if need_to_split:
                x, success = self.conv_split_by_batch(x)
                if not success:  # if unable to split by batch, try by channel
                    if self._subsampling == "dw_striding":
                        x = self.conv_split_by_channel(x)
                    else:
                        x = self.conv(x)  # try anyway
            else:
                x = self.conv(x)
        else:
            x = self.conv(x)

        # Flatten Channel and Frequency Axes
        if self.conv2d_subsampling:
            b, c, t, f = x.size()
            x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        # Transpose to Channel Last mode
        else:
            x = x.transpose(1, 2)

        if mask is None:
            return x, None

        max_audio_length = x.shape[1]
        feature_lens = mask.sum(1)
        padding_length = torch.ceil(feature_lens / self.subsampling_factor)
        if self.is_causal and self.subsampling_causal_cond:
            feature_lens_remainder = feature_lens % self.subsampling_factor
            padding_length[feature_lens_remainder != 1] += 1
        pad_mask = torch.arange(0, max_audio_length, device=x.device).expand(
1396
1397
            padding_length.size(0), -1
        ) < padding_length.unsqueeze(1)
1398
1399
        return x, pad_mask.unsqueeze(1)

1400
    def reset_parameters(self) -> None:
1401
1402
1403
1404
1405
        # initialize weights
        if self._subsampling == "dw_striding":
            with torch.no_grad():
                # init conv
                scale = 1.0 / self._kernel_size
1406
                dw_max = (self._kernel_size**2) ** -0.5
1407
1408
1409
1410
1411
1412
                pw_max = self._conv_channels**-0.5

                torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
                torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)

                for idx in range(2, len(self.conv), 3):
1413
1414
1415
1416
                    torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)
1417
1418
1419
1420

                # init fc (80 * 64 = 5120 from https://github.com/kssteven418/
                # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/
                # src/models/conformer_encoder.py#L487
1421
                fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
1422
1423
1424
                torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
                torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)

1425
    def conv_split_by_batch(self, x: Tensor) -> tuple[Tensor, bool]:
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
        """Tries to split input by batch, run conv and concat results"""
        b, _, _, _ = x.size()
        if b == 1:  # can't split if batch size is 1
            return x, False

        if self.subsampling_conv_chunking_factor > 1:
            cf = self.subsampling_conv_chunking_factor
        else:
            # avoiding a bug / feature limiting indexing of tensors to 2**31
            # see https://github.com/pytorch/pytorch/issues/80020
            x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
            p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
            cf = 2**p

        new_batch_size = b // cf
        if new_batch_size == 0:  # input is too big
            return x, False

        return (
1445
1446
1447
            torch.cat(
                [self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]
            ),
1448
1449
1450
            True,
        )

1451
    def conv_split_by_channel(self, x: Tensor) -> Tensor:
1452
        """For dw convs, tries to split input by time, run conv and concat
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
        results"""
        x = self.conv[0](x)  # full conv2D
        x = self.conv[1](x)  # activation

        for i in range(self._sampling_num - 1):
            _, c, t, _ = x.size()

            if self.subsampling_conv_chunking_factor > 1:
                cf = self.subsampling_conv_chunking_factor
            else:
                # avoiding a bug / feature limiting indexing of tensors
                # to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
                cf = 2**p

            new_c = int(c // cf)
            if new_c == 0:
                new_c = 1

            new_t = int(t // cf)
            if new_t == 0:
                new_t = 1

1477
1478
1479
            x = self.channel_chunked_conv(
                self.conv[i * 3 + 2], new_c, x
            )  # conv2D, depthwise
1480
1481
1482

            # splitting pointwise convs by time
            x = torch.cat(
1483
                [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)],
1484
1485
1486
1487
1488
                2,
            )  # conv2D, pointwise
            x = self.conv[i * 3 + 4](x)  # activation
        return x

1489
1490
1491
    def channel_chunked_conv(
        self, conv: torch.nn.Module, chunk_size: int, x: Tensor
    ) -> Tensor:
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
        """Performs channel chunked convolution"""

        ind = 0
        out_chunks = []
        for chunk in torch.split(x, chunk_size, 1):
            step = chunk.size()[1]

            if self.is_causal:
                chunk = nn.functional.pad(
                    chunk,
                    pad=(
                        self._kernel_size - 1,
                        self._stride - 1,
                        self._kernel_size - 1,
                        self._stride - 1,
                    ),
                )
                ch_out = nn.functional.conv2d(
                    chunk,
1511
1512
                    conv.weight[ind : ind + step, :, :, :],
                    bias=conv.bias[ind : ind + step],
1513
1514
1515
1516
1517
1518
1519
                    stride=self._stride,
                    padding=0,
                    groups=step,
                )
            else:
                ch_out = nn.functional.conv2d(
                    chunk,
1520
1521
                    conv.weight[ind : ind + step, :, :, :],
                    bias=conv.bias[ind : ind + step],
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
                    stride=self._stride,
                    padding=self._left_padding,
                    groups=step,
                )
            out_chunks.append(ch_out)
            ind += step

        return torch.cat(out_chunks, 1)

    def change_subsampling_conv_chunking_factor(
1532
1533
1534
1535
1536
1537
1538
        self, subsampling_conv_chunking_factor: int
    ) -> None:
        if (
            subsampling_conv_chunking_factor != -1
            and subsampling_conv_chunking_factor != 1
            and subsampling_conv_chunking_factor % 2 != 0
        ):
1539
            raise ValueError(
1540
                "subsampling_conv_chunking_factor should be -1, 1, or a power of 2"
1541
1542
1543
1544
            )
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor


def calc_length_int(
    lengths: int,
    all_paddings: int,
    kernel_size: int,
    stride: int,
    ceil_mode: bool,
    repeat_num: int = 1,
) -> int:
    """Integer-only variant of calc_length for meta-safe shape computation.

    Applies ``length = round((length + all_paddings - kernel_size) / stride
    + 1)`` ``repeat_num`` times, where ``round`` is ceil when ``ceil_mode``
    is true and floor otherwise. Operates purely on Python numbers, so it
    can be used during meta tensor initialization.

    Args:
        lengths: input length.
        all_paddings: total padding (left + right) applied per stage.
        kernel_size: kernel size of each stage.
        stride: stride of each stage.
        ceil_mode: round up instead of down after each stage.
        repeat_num: number of identical stages.

    Returns:
        The output length as an int.
    """
    add_pad = all_paddings - kernel_size
    round_fn = math.ceil if ceil_mode else math.floor
    length: float = float(lengths)
    for _ in range(repeat_num):
        length = round_fn((length + add_pad) / stride + 1.0)
    return int(length)


####  multihead attention starts here
class AttModule(nn.Module):
    """Attention abstraction module"""

1572
    def __init__(self) -> None:
1573
1574
1575
        super().__init__()
        self.export_mode = False

1576
    def set_export(self, mode: bool = True) -> None:
1577
1578
1579
1580
1581
1582
        """set the export mode"""
        self.export_mode = mode

    def forward(
        self,
        x: Tensor,
1583
1584
1585
1586
        memory: Tensor | None = None,
        pos_emb: Tensor | None = None,
        att_mask: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
1587
1588
1589
        """AttModule forward

        Args:
1590
1591
1592
1593
            x: input tensor.
            memory: memory tensor.
            pos_emb: positional encoder embedding.
            att_mask: attention mask tensor.
1594
1595
1596
1597
        """
        return x, memory, pos_emb, att_mask


class AttBlock(BlockBase, AttModule):
    """Attention Block module to support both Attention and Block module."""

    def memory_dims(self, max_len: bool = False) -> tuple[int, int]:
        """Return the memory dimensions ``(1, self.input_size)``.

        ``max_len`` is accepted for interface compatibility but not used.
        """
        return (1, self.input_size)


def masked_softmax(
1607
    scores: Tensor,
1608
    mask: Tensor | None,
1609
) -> Tensor:
1610
1611
1612
1613
    if mask is not None:
        mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
        scores = scores.masked_fill(mask, -torch.inf)
        attn = torch.softmax(scores, dim=-1).masked_fill(
1614
1615
            mask, 0.0
        )  # (batch, head, time1, time2)
1616
1617
1618
1619
1620
1621
    else:
        attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
    return attn


class MultiHeadedAttention(nn.Module):
1622
    """Multi-Head Attention layer with optional relative position embedding
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
    and GLU.

    Args:
        n_head: int
            the number of heads.
        n_feat: int
            input size features.
        dropout_rate: float
            dropout rate.
        attention_inner_dim: int, optional
            the attention dimension used in the class,
            it can be different from the input dimension n_feat.
            default: -1 (equal to n_feat).
        use_pt_scaled_dot_product_attention: bool, optional
            if set True, use pytorch scaled dot product attention in training.
1638
1639
            NOTE: this will NOT be used in ONNX decoding due to a lack of
            support.  In that case, we use the original attention
1640
1641
1642
            implementation, which shows no regression.
            default: False.
        n_value: int, optional
1643
            if set to values other than -1, use a different dimension for
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
            value. With the default value (i.e. -1), it is backward compatible.
        group_size: int, optional. must divide `n_head`
            if group_size > 1:       GQA
            if group_size = 1:       MHA
            if group_size = n_head:  MQA
    """

    inv_sqrt_d_k: torch.jit.Final[float]
    h: torch.jit.Final[int]
    h_k: torch.jit.Final[int]
    g: torch.jit.Final[int]

    def __init__(
        self,
1658
1659
1660
1661
1662
1663
1664
1665
        n_head: int,
        n_feat: int,
        dropout_rate: float,
        attention_inner_dim: int = -1,
        glu_type: str = "swish",
        bias_in_glu: bool = True,
        use_pt_scaled_dot_product_attention: bool = False,
        n_value: int = -1,
1666
        group_size: int = 1,
1667
    ) -> None:
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
        super().__init__()
        if n_value == -1:
            n_value = n_feat
        if attention_inner_dim == -1:
            attention_inner_dim = n_feat
        assert attention_inner_dim % n_head == 0

        # We assume d_v always equals d_k
        self.d_k = attention_inner_dim // n_head
        self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
        self.h = n_head
        assert n_head % group_size == 0, "group_size must divide n_head"
        self.g = group_size
        self.h_k = n_head // group_size

        self.linear_q = nn.Linear(n_feat, attention_inner_dim)
        self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
        self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
        self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)

1688
        self.attn = torch.jit.Attribute(None, Tensor | None)
1689
1690
        self.dropout = nn.Dropout(p=dropout_rate)
        self.dropout_rate = dropout_rate
1691
        self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708

        if use_pt_scaled_dot_product_attention and group_size > 1:
            raise ValueError("Cannot use PT Scaled Attention with GQA")

        # Torchscript eager quantization.  Note that these functions below are
        # NOOPs and have very little impact on performance unless quantization
        # is enabled.
        self.quant_q = torch.ao.quantization.QuantStub()
        self.quant_x = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.ffunc = torch.ao.nn.quantized.FloatFunctional()

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
1709
1710
1711
1712
        pos_k: Tensor | None,
        pos_v: Tensor | None,
        mask: Tensor | None,
        relative_attention_bias: Tensor | None = None,
1713
    ) -> Tensor:
1714
1715
1716
        """Compute 'Scaled Dot Product Attention'.

        Args:
1717
1718
1719
1720
1721
1722
            query: query tensor (batch, time1, size)
            key: key tensor (batch, time2, size)
            value: value tensor (batch, time1, size)
            pos_k: key tensor used for relative positional embedding.
            pos_v: value tensor used for relative positional embedding.
            mask: mask tensor (batch, time1, time2)
1723
            relative_attention_bias: bias added to attention logits w.r.t.
1724
                relative positions
1725
1726
1727
1728
                (1, n_head, time1, time2)
        """
        n_batch = query.size(0)

1729
1730
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)  # (b, t, d)
        k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k)  # (b, t, d)
1731
        v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
1732
1733
1734
1735
1736
        q = (
            q.transpose(1, 2)
            if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting()
            else q.transpose(1, 2) * self.inv_sqrt_d_k
        )
1737
1738
1739
        k = k.transpose(1, 2)  # (batch, head_k, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head_k, time2, d_k)

1740
        if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting():
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
            attn_mask = None
            if mask is not None:
                mask = mask.unsqueeze(1)
                if relative_attention_bias is not None:
                    attn_mask = mask + relative_attention_bias
                else:
                    attn_mask = mask
                if mask.dtype != q.dtype:
                    attn_mask = attn_mask.to(q.dtype)

1751
1752
            with torch.nn.attention.sdpa_kernel(
                [
cyyever's avatar
cyyever committed
1753
1754
1755
1756
                    torch.nn.attention.SDPBackend.FLASH_ATTENTION,
                    torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
                    torch.nn.attention.SDPBackend.MATH,
                    torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
1757
1758
                ]
            ):
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
                x = torch.nn.functional.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=attn_mask,
                    dropout_p=self.dropout_rate,
                )
        else:
            if self.h != self.h_k:
                q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
                A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
            else:
                A = torch.matmul(q, k.transpose(-2, -1))
            if pos_k is not None:
                if self.h != self.h_k:
                    B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
                else:
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
                    reshape_q = (
                        q.contiguous()
                        .view(n_batch * self.h, -1, self.d_k)
                        .transpose(0, 1)
                    )  # (t1,nh,dk)
                    B = torch.matmul(
                        reshape_q, pos_k.transpose(-2, -1)
                    )  # pos_k: (t1,dk,t2)
                    B = B.transpose(0, 1).view(
                        n_batch, self.h, pos_k.size(0), pos_k.size(1)
                    )
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
                scores = A + B
            else:
                scores = A

            if relative_attention_bias is not None:
                scores = scores + relative_attention_bias

            attn = masked_softmax(scores, mask)  # (batch, head, time1, time2)

            self.attn = attn

            p_attn = self.dropout(attn)
1799
            x = torch.matmul(p_attn.to(v.dtype), v)  # (batch, head, time1, d_k)
1800
            if pos_v is not None:
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
                reshape_attn = (
                    p_attn.contiguous()
                    .view(n_batch * self.h, pos_v.size(0), pos_v.size(1))
                    .transpose(0, 1)
                )  # (t1, bh, t2)

                attn_v = (
                    torch.matmul(reshape_attn, pos_v)
                    .transpose(0, 1)
                    .contiguous()
                    .view(n_batch, self.h, pos_v.size(0), self.d_k)
                )
1813
                x = x + attn_v
1814
1815
1816
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k)
        )  # (batch, time1, d_model)
1817
1818
1819
1820
1821
1822
1823
1824

        return self.linear_out(x)  # (batch, time1, d_model)


class MultiSequential(torch.nn.Sequential):
    """``torch.nn.Sequential`` variant for modules with several inputs/outputs.

    The positional arguments are unpacked into the first child; every later
    child is called with the previous child's return value unpacked, so each
    child must return something the next one can take as ``*args``.
    """

    @torch.jit.ignore
    def forward(self, *args) -> tuple:
        """Chain ``args`` through every child module in registration order."""
        carried = args
        for layer in self:
            carried = layer(*carried)
        return carried


1832
def get_offset(input_layer: str, time_reduction: int) -> int:
1833
    """Get an offset. We will use the offset for determining #frames of a
1834
1835
1836
    subsampled feature.

    Args:
1837
1838
        input_layer: Type of an input layer
        time_reduction: time reduction factor for downsampling a feature
1839
1840
1841
1842
1843
    Returns:
        int: offset
    """
    if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4:
        return 3
1844
    if input_layer in ("conv2d",) and time_reduction == 6:
1845
1846
1847
1848
1849
1850
        return 1
    if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8:
        return 7
    return 0


1851
def unfold_tensor(xs_pad: Tensor, max_seq_len: int) -> Tensor:
    """Split a (N, T, D) tensor into non-overlapping windows along time.

    The result has shape (N * T', max_seq_len, D), where
    T' = T // max_seq_len. Trailing frames that do not fill a whole window
    are dropped, so callers should pad T to a multiple of max_seq_len first
    if those frames matter.

    Args:
        xs_pad: input tensor with shape (N, T, D)
        max_seq_len: maximum sequence length (window size)
    """
    _, _, feat_dim = xs_pad.shape
    # (N, T, D) -> (N, D, 1, T): time becomes the width of a 1-row "image".
    as_image = xs_pad.transpose(1, 2).unsqueeze(2)
    # Non-overlapping (1 x max_seq_len) patches -> (N, D * max_seq_len, T').
    patches = F.unfold(
        as_image,
        kernel_size=(1, max_seq_len),
        stride=(1, max_seq_len),
    )
    batch, _, n_chunks = patches.shape
    return (
        patches.view(batch, -1, max_seq_len, n_chunks)  # (N, D, L, T')
        .permute(0, 3, 2, 1)  # (N, T', L, D)
        .contiguous()
        .view(-1, max_seq_len, feat_dim)  # (N * T', L, D)
    )