phi4mm_audio.py 49.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import abc
import math
10
from typing import Any, Literal
11
12
13
14
15
16

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
17
18
19
    CheckpointWrapper,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
20
21
22
from transformers import PretrainedConfig

from vllm.model_executor.models.phi4mm_utils import (
23
24
25
26
27
28
29
30
31
32
33
34
    AbsolutePositionalEncoding,
    ConvModule,
    FeedForward,
    MeanVarianceNormLayer,
    MultiHeadedAttention,
    MultiSequential,
    NemoConvSubsampling,
    T5RelativeAttentionLogitBias,
    adaptive_enc_mask,
    get_offset,
    unfold_tensor,
)
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


class ConformerEncoderLayer(nn.Module):
    """A single Conformer block (https://arxiv.org/abs/2005.08100).

    The block chains, each with a residual connection: a half-weighted
    feed-forward module, layer-normalised multi-head self-attention, a
    convolution module, and a second half-weighted feed-forward module.
    A final LayerNorm is applied to the result.

    Args:
        d_model: int
            attention (model) dimension.
        ext_pw_out_channel: int
            if > 0, channel size of the last pointwise conv applied
            after the swish activation.
        depthwise_seperable_out_channel: int
            if non-zero, used as the output channel count of the second
            conv1d layer; if 0, the second conv1d layer is skipped.
        depthwise_multiplier: int
            number of input_dim channel duplications, used to compute
            the hidden channels of the Conv1D.
        n_head: int
            number of heads of the multi-head attention module.
        d_ffn: int
            output size of the feed-forward blocks.
        ext_pw_kernel_size: int
            kernel size of the pointwise conv of the conformer.
        kernel_size: int
            kernel size.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if True, the convolution has no access to future frames.
            default False.
        batch_norm: bool, optional
            if True, apply batchnorm before the activation in the
            ConvModule. default False.
        activation: str, optional
            activation function name, one of
            ["relu", "swish", "sigmoid"]; sigmoid is only used with
            "glu_in_fnn=True". default "relu".
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, mean computed from the accumulated
             history up to the current chunk_se.
            2 for streaming SE, mean computed from the current chunk
             only.
            default 0.
        chunk_size: int, optional
            chunk size for the cnn. default 18.
        conv_activation: str, optional
            activation function used in the ConvModule. default "relu".
        conv_glu_type: str, optional
            activation function used for the glu inside the ConvModule.
            default "sigmoid".
        bias_in_glu: bool, optional
            if True, use an additive bias in the weight module before
            GLU.
        linear_glu_in_convm: bool, optional
            if True, use the GLULinear module, otherwise the
            GLUPointWiseConv module. default False.
        attention_inner_dim: int, optional
            if -1, the attention dim of the k/q/v linears equals
            d_model; otherwise attention_inner_dim is used. default -1.
        attention_glu_type: str, optional
            activation function for the glu used in the multi-head
            attention. default "swish".
        activation_checkpointing: str, optional
            activation-checkpointing configuration with the fields
            {"module", "interval", "offload"}:
                "module": str, one of ["transformer", "attention"],
                    selects which module is checkpointed.
                "interval": int, default 1, checkpoint every
                    `interval` layers.
                "offload": bool, default False; if True activations
                    are offloaded to cpu and reloaded in backward,
                    otherwise they are recomputed in backward.
            default "". This constructor accepts the value but does not
            read it.
        export: bool, optional
            if True, remove the padding from the convolutional layers
            to allow onnx conversion for inference. default False.
        use_pt_scaled_dot_product_attention: bool, optional
            if True, use pytorch's scaled dot product attention
            implementation in training.
        attn_group_sizes: int, optional
            number of attention groups, default 1:
            1 = typical Multi-Head Attention,
            1 < attn_group_sizes < attention_heads = Grouped-Query
            Attention,
            attn_group_sizes = attention_heads = Multi-Query Attention.
    """

    def __init__(
        self,
        d_model: int = 512,
        ext_pw_out_channel: int = 0,
        depthwise_seperable_out_channel: int = 256,
        depthwise_multiplier: int = 1,
        n_head: int = 4,
        d_ffn: int = 2048,
        ext_pw_kernel_size: int = 1,
        kernel_size: int = 3,
        dropout_rate: float = 0.1,
        causal: bool = False,
        batch_norm: bool = False,
        activation: str = "relu",
        chunk_se: int = 0,
        chunk_size: int = 18,
        conv_activation: str = "relu",
        conv_glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
        linear_glu_in_convm: bool = False,
        attention_inner_dim: int = -1,
        attention_glu_type: str = "swish",
        activation_checkpointing: str = "",
        export: bool = False,
        use_pt_scaled_dot_product_attention: bool = False,
        attn_group_sizes: int = 1,
    ) -> None:
        super().__init__()

        # Both feed-forward modules share one configuration.
        ffn_kwargs = {
            "d_model": d_model,
            "d_inner": d_ffn,
            "dropout_rate": dropout_rate,
            "activation": activation,
            "bias_in_glu": bias_in_glu,
        }
        # Positional arguments of ConvModule, in the order it expects them.
        conv_args = (
            d_model,
            ext_pw_out_channel,
            depthwise_seperable_out_channel,
            ext_pw_kernel_size,
            kernel_size,
            depthwise_multiplier,
            dropout_rate,
            causal,
            batch_norm,
            chunk_se,
            chunk_size,
            conv_activation,
            conv_glu_type,
            bias_in_glu,
            linear_glu_in_convm,
        )

        # Sub-modules are created in a fixed order so that parameter
        # registration order stays stable.
        self.feed_forward_in = FeedForward(**ffn_kwargs)
        self.self_attn = MultiHeadedAttention(
            n_head,
            d_model,
            dropout_rate,
            attention_inner_dim,
            attention_glu_type,
            bias_in_glu,
            use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
            group_size=attn_group_sizes,
        )
        self.conv = ConvModule(*conv_args, export=export)
        self.feed_forward_out = FeedForward(**ffn_kwargs)

        self.layer_norm_att = nn.LayerNorm(d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        pos_k: torch.Tensor,
        pos_v: torch.Tensor,
        mask: torch.Tensor,
        relative_attention_bias: Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the Conformer block on a batch of features.

        Args:
            x: input feature of shape (batch, max_time_in, size)
            pos_k: positional key embedding.
            pos_v: positional value embedding.
            mask: mask for x (batch, max_time_in)
            relative_attention_bias: bias added to attention logits w.r.t.
                relative positions (1, n_head, time1, time2)

        Returns:
            The layer-normalised block output, followed by `pos_k`,
            `pos_v` and `mask` passed through unchanged.
        """
        # First feed-forward module, half-weighted residual.
        hidden = x + 0.5 * self.feed_forward_in(x)

        # Self-attention on the normalised activations; query, key and
        # value are all the same tensor.
        attn_input = self.layer_norm_att(hidden)
        attn_output = self.self_attn(
            attn_input,
            attn_input,
            attn_input,
            pos_k,
            pos_v,
            mask,
            relative_attention_bias=relative_attention_bias,
        )
        hidden = hidden + attn_output

        # Convolution module, then second half-weighted feed-forward module.
        hidden = hidden + self.conv(hidden)
        hidden = hidden + 0.5 * self.feed_forward_out(hidden)

        return self.layer_norm(hidden), pos_k, pos_v, mask


class TransformerEncoderBase(abc.ABC, nn.Module):
    """The Base class for Transformer based encoders

    Please set causal = True in streaming model
    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        input_layer: str, optional
            input layer type before Conformer,
            one of ["linear", "conv2d", "custom", "vgg2l", "embed"],
            default "conv2d"
        cnn_out: int, optional
            the number of CNN channels before Conformer.
            default -1.
        cnn_layer_norm: bool, optional
            layer norm between Conformer and the first CNN.
            default False.
        time_reduction: int, optional
            time reduction factor
            default 4
        dropout_rate: float, optional
            dropout rate. default 0.1
        padding_idx: int, optional
            padding index for input_layer=embed
            default -1
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention
            (Q*K^T + B) implemented in cmb.basics.embedding.
            [T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
305
            additional method-specific arguments can be provided (see
306
307
308
309
310
311
312
313
314
315
316
317
318
            transformer_base.py)
        positional_dropout_rate: float, optional
            dropout rate after positional encoding. default 0.0
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling.
            default None
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True).
            if True or feat_time, the extra padding is added into non full
            supraframe utts in batch.
            Default: none
        attention_group_size: int, optional
319
            the number of groups to use for attention, default 1
320
321
            (Multi-Head Attention),
            1 = typical Multi-Head Attention,
322
            1 < attention_group_size < attention_heads = Grouped-Query
323
            Attention
324
            attention_group_size = attention_heads = Multi-Query Attention
325
326
327
328
    """

    def __init__(
        self,
329
        input_size: int,
330
331
        chunk_size: int | list[int],
        left_chunk: int | list[int],
332
333
334
335
336
337
338
339
        attention_dim: int = 256,
        attention_heads: int = 4,
        input_layer: str = "nemo_conv",
        cnn_out: int = -1,
        cnn_layer_norm: bool = False,
        time_reduction: int = 4,
        dropout_rate: float = 0.0,
        padding_idx: int = -1,
340
        relative_attention_bias_args: dict[str, Any] | None = None,
341
        positional_dropout_rate: float = 0.0,
342
        nemo_conv_settings: dict[str, Any] | None = None,
343
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
344
        attention_group_size: int = 1,
345
        encoder_embedding_config: dict[str, Any] | None = None,
346
    ) -> None:
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
        super().__init__()
        self.input_size = input_size
        self.input_layer = input_layer
        self.chunk_size = chunk_size
        self.left_chunk = left_chunk
        self.attention_dim = attention_dim
        self.num_heads = attention_heads
        self.attention_group_size = attention_group_size
        self.time_reduction = time_reduction
        self.nemo_conv_settings = nemo_conv_settings
        self.encoder_embedding_config = encoder_embedding_config

        if self.input_layer == "nemo_conv":
            default_nemo_conv_settings = {
                "subsampling": "dw_striding",
                "subsampling_factor": self.time_reduction,
                "feat_in": input_size,
                "feat_out": attention_dim,
                "conv_channels": 256,
                "subsampling_conv_chunking_factor": 1,
                "activation": nn.ReLU(),
                "is_causal": False,
            }
            # Override any of the defaults with the incoming, user settings
            if nemo_conv_settings:
                default_nemo_conv_settings.update(nemo_conv_settings)
                for i in ["subsampling_factor", "feat_in", "feat_out"]:
374
375
376
                    assert i not in nemo_conv_settings, (
                        "{i} should be specified outside of the NeMo dictionary"
                    )
377

378
379
380
            self.embed = NemoConvSubsampling(
                **default_nemo_conv_settings,
            )
381
382
383
        else:
            raise ValueError("unknown input_layer: " + input_layer)

384
385
386
        self.pos_emb = AbsolutePositionalEncoding(
            attention_dim, positional_dropout_rate
        )
387
388
389

        self.relative_attention_bias_type = (
            relative_attention_bias_args.get("type")
390
391
392
            if relative_attention_bias_args
            else None
        )
393
        if self.relative_attention_bias_type == "t5":
394
395
396
            assert self.num_heads % self.attention_group_size == 0, (
                "attention_group_size must divide n_head"
            )
397
398
399
            self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
                self.num_heads // self.attention_group_size,
                max_distance=relative_attention_bias_args.get(
400
401
402
                    "t5_bias_max_distance", 1000
                ),
                symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
403
404
405
406
            )
        else:
            raise NotImplementedError

407
        self.encoder_embedding = MeanVarianceNormLayer(
408
409
            self.encoder_embedding_config["input_size"]
        )
410

411
    def compute_lens_change(
412
413
        self, feature_lens: int | torch.Tensor
    ) -> int | torch.Tensor:
414
415
416
        """feature_lens: int
        return updated feature lens.

417
418
        This used to return a different lambda function for each case that
        computed the right thing.  That does not work within Torchscript.
419
420
421
422
423
424
        If you really need this to be faster, create nn.Module()-s for all
        the cases and return one of them.  Torchscript does support that.
        """
        if self.input_layer == "nemo_conv":
            # Handle the special causal case
            subsampling_causal_cond = self.nemo_conv_settings.get(
425
426
427
428
429
430
                "subsampling", "dw_striding"
            ) in [
                "dw_striding",
                "striding",
                "striding_conv1d",
            ]
431
432
            is_causal = self.nemo_conv_settings.get("is_causal", False)
            if is_causal and subsampling_causal_cond:
433
434
435
436
437
                lens_change = (
                    torch.ceil(feature_lens / self.time_reduction).long()
                    if isinstance(feature_lens, Tensor)
                    else math.ceil(feature_lens / self.time_reduction)
                )
438
439
440
441
442
443
                feature_lens_remainder = feature_lens % self.time_reduction
                if isinstance(feature_lens, Tensor):
                    lens_change[feature_lens_remainder != 1] += 1
                elif feature_lens_remainder != 1:
                    lens_change += 1
                return lens_change
444
            ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil
445
446
447
            return ceil_func(feature_lens / self.time_reduction)

    @abc.abstractmethod
448
    def forward(self) -> Any:
449
450
        """Abstract forward method implementation."""

451
    def _chunk_size_selection(
452
        self,
453
454
        chunk_size: int | list[int] | None = None,
        left_chunk: int | list[int] | None = None,
455
    ) -> tuple[int, int]:
456
457
458
459
460
461
462
463
464
        """If chunk size is a list, we will randomly select a chunk size."""

        if chunk_size is None:
            chunk_size = self.chunk_size
        if left_chunk is None:
            left_chunk = self.left_chunk
        if isinstance(chunk_size, list):
            # Variable chunk size during training
            chunk_size_index = int(
465
466
                torch.randint(low=0, high=len(chunk_size), size=(1,))
            )
467
468
469
            chunk_size_train_eff = chunk_size[chunk_size_index]
            if not isinstance(left_chunk, list):
                raise ValueError(
470
471
                    "Since chunk_size is a list, left_chunk must be a list"
                )
472
473
            if len(left_chunk) != len(chunk_size):
                raise ValueError(
474
                    "The length of left_chunk must be the same as length of chunk_size."
475
476
477
478
479
480
481
482
                )
            left_chunk_train_eff = left_chunk[chunk_size_index]
        else:
            chunk_size_train_eff = chunk_size
            left_chunk_train_eff = left_chunk

        return chunk_size_train_eff, left_chunk_train_eff

483
    def _get_embed_class(self, embed: nn.Module) -> nn.Module:
484
485
486
487
488
489
490
491
492
493
        # pylint: disable=protected-access
        is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
        is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
        embed_class = embed
        if is_embed_using_act_chkpt:
            embed_class = embed._checkpoint_wrapped_module
        if is_embed_fsdp_wrapped:
            embed_class = embed.module
        return embed_class

494
    def _forward_embeddings_core(
495
496
        self, input_tensor: torch.Tensor, masks: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
497
498
499
500
501
        embed_class = self._get_embed_class(self.embed)
        assert isinstance(embed_class, NemoConvSubsampling)
        input_tensor, masks = self.embed(input_tensor, masks)
        return input_tensor, masks

502
503
    def _position_embedding(
        self, input_tensor: torch.Tensor
504
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
505
506
507
508
        pos_k = None
        pos_v = None
        if self.relative_attention_bias_layer is None:
            input_tensor = self.pos_emb(
509
510
                input_tensor
            )  # default to add abs sinusoid embedding
511
512
        return pos_k, pos_v

513
514
515
516
    def _streaming_mask(
        self,
        seq_len: int,
        batch_size: int,
517
518
        chunk_size: int | list[int],
        left_chunk: int | list[int],
519
520
521
522
    ) -> torch.Tensor:
        chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
            chunk_size, left_chunk
        )
523
524
525
526
527

        # Create mask matrix for streaming
        # S stores start index. if chunksize is 18, s is [0,18,36,....]
        chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)

528
529
530
531
532
533
534
        enc_streaming_mask = (
            adaptive_enc_mask(
                seq_len, chunk_start_idx, left_window=left_chunk_train_eff
            )
            .unsqueeze(0)
            .expand([batch_size, -1, -1])
        )
535
536
        return enc_streaming_mask

537
538
539
540
    def forward_embeddings(
        self,
        xs_pad: torch.Tensor,
        masks: torch.Tensor,
541
542
543
        chunk_size_nc: int | list[int] | None = None,
        left_chunk_nc: int | list[int] | None = None,
    ) -> (
544
545
        tuple[
            torch.Tensor,
546
547
            torch.Tensor | None,
            torch.Tensor | None,
548
549
            torch.Tensor,
            torch.Tensor,
550
551
        ]
        | tuple[
552
            torch.Tensor,
553
554
            torch.Tensor | None,
            torch.Tensor | None,
555
556
557
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
558
559
        ]
    ):
560
561
562
563
564
565
566
        """Forwarding the inputs through the top embedding layers

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor
                input mask
567
            chunk_size_nc: (optional, default is None) chunk size for
568
569
570
571
572
573
574
575
576
577
578
579
                            non-causal layers
            left_chunk_nc: (optional, default is None) # of left chunks for
                            non-causal layers
        """
        # pylint: disable=R0915
        # get new lens.
        seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
        if seq_len <= 0:
            raise ValueError(
                f"""The sequence length after time reduction is invalid: 
                {seq_len}. Your input feature is too short. Consider 
                filtering out the very short sentence from data 
580
581
                loader""",
            )
582
583
584

        batch_size = xs_pad.shape[0]

585
586
587
        enc_streaming_mask = self._streaming_mask(
            seq_len, batch_size, self.chunk_size, self.left_chunk
        )
588
589
590
        device = xs_pad.device
        enc_streaming_mask = enc_streaming_mask.to(device)
        xs_pad = xs_pad.to(device)
591
592

        input_tensor = xs_pad
593
        input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
594
595
596
597
598
599
600
601
602
603
604

        streaming_mask = enc_streaming_mask
        if streaming_mask is not None and masks is not None:
            hs_mask = masks & streaming_mask
        elif masks is not None:
            hs_mask = masks
        else:
            hs_mask = streaming_mask

        if chunk_size_nc is not None:
            enc_streaming_mask_nc = self._streaming_mask(
605
606
                seq_len, batch_size, chunk_size_nc, left_chunk_nc
            )
607
608
            if device.type != "cpu":
                enc_streaming_mask_nc = enc_streaming_mask_nc.to(device)
609
610
611
612
613
614
615
616
617
618
619
620
621
            if masks is not None:
                hs_mask_nc = masks & enc_streaming_mask_nc
            else:
                hs_mask_nc = enc_streaming_mask_nc
        else:
            hs_mask_nc = None

        pos_k, pos_v = self._position_embedding(input_tensor)

        if chunk_size_nc is None:
            return input_tensor, pos_k, pos_v, hs_mask, masks
        return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc

622
    def get_offset(self) -> int:
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
        """Returns offset used when retaining inputs for decoding.

        This is essentially, how many additional frames have to be added to
        the front-end CNN input to ensure it can produce a single output.
        So if the "padding" parameter is 0, typically offset will be > 0.
        """
        return get_offset(self.input_layer, self.time_reduction)


class ConformerEncoder(TransformerEncoderBase):
    """ConformerEncoder module.
    see original paper for more details:
        https://arxiv.org/abs/2005.08100

    Please set causal = True in streaming model
    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        num_lang: int
            This parameter is used to store the number of languages in the
            lang_dict, only used for multiseed/multilingual models.
            default None.
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        linear_units:
            the number of units of position-wise feed forward.
            default 2048
        num_blocks:
            number of Transformer layer. default 6
        dropout_rate: float, optional
            dropout rate. default 0.1
        input_layer: str, optional
            input layer type before Conformer,
            upstream documents one of
            ["linear", "conv2d", "custom", "vgg2l", "embed"];
            default "nemo_conv"
        causal: bool, optional
            if set to True, convolution have no access
             to future frames. default True.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation
            in ConvModule layer of the conformer.
            default False
        cnn_out: int, optional
            the number of CNN channels before Conformer.
            default -1.
        cnn_layer_norm: bool, optional
            layer norm between Conformer and the first CNN.
            default False.
        ext_pw_out_channel: int, optional
            the number of channel for CNN
            before depthwise_separable_CNN.
            If 0 then use linear. default 0.
        ext_pw_kernel_size: int, optional
            kernel size of N before depthwise_separable_CNN.
            only work for ext_pw_out_channel > 0.
            default 1
        depthwise_seperable_out_channel: int, optional
            the number of channel for
            depthwise_separable_CNN.
            default 256.
        depthwise_multiplier: int, optional
            the number of multiplier for
            depthwise_separable_CNN.
            default 1.
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
             by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
             by only the current chunk.
            default 0.
        kernel_size: int, optional
            the number of kernels for depthwise_separable_CNN.
            default 3.
        activation: str, optional
            FeedForward block activation.
            one of ["relu", "swish", "sigmoid"]
            default "relu".
        conv_activation: str, optional
            activation function used in ConvModule part
            of the conformer, default "relu".
        conv_glu_type: str, optional
            activation used by the GLU in depthwise_separable_CNN,
            default "sigmoid"
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
             before GLU. default True
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
             otherwise, used GLUPointWiseConv module.
              default to False.
        attention_glu_type: str
            only work for glu_in_attention !=0
            default "swish".
        export: bool, optional
            if set to True, it removes the padding from convolutional layers
             and allow the onnx conversion for inference.
              default False.
        extra_layer_output_idx: int
            the layer index to be exposed. default -1.
        extra_multi_layer_output_idxs: list[int]
            stored (as a copy) on the instance; default is an empty list.
        activation_checkpointing: str, optional
            a dictionary of {"module","interval","offload"}, where
                "module": str
                    accept ["transformer", "attention"] to select
                    which module should do activation checkpointing.
                "interval": int, default 1,
                    interval of applying activation checkpointing,
                    interval = 1 means that we apply checkpointing
                    on every layer (if activation), otherwise,
                    we apply it every x interval.
                "offload": bool, default False,
                    if set to True, we offload activation to cpu and
                    reload it during backward, otherwise,
                    we recalculate activation in backward.
            default "".
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention
            (Q*K^T + B) implemented in cmb.basics.embedding.
            [T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
            additional method-specific arguments can be provided (see
            transformer_base.py)
        time_reduction: int optional
            time reduction factor
            default 4
        use_pt_scaled_dot_product_attention: whether to use pytorch scaled
            dot product attention in training.
            Default: False
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling.
            default: None
            usage: nemo_conv_settings=
                {
                    "subsampling":
                    dw_striding/striding/dw_striding_conv1d/striding_conv1d,
                    "conv_channels": int,
                    "subsampling_conv_chunking_factor": int,
                    "is_causal": True/False
                }
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True)
            Default: none
        replication_pad_for_subsample_embedding:  For batched-streaming
            decoding, use "replication" padding for the cache at start of
            utterance.
            Default: False
        attention_group_size: int, optional
            the number of groups to use for attention, default 1
            (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query
            Attention
            attention_group_size = attention_heads = Multi-Query Attention
        encoder_embedding_config: dict, optional
            forwarded unchanged to the base-class constructor.
            default None.
    """

    extra_multi_layer_output_idxs: list[int]

    def __init__(  # pylint: disable-all
        self,
        input_size: int,
        chunk_size: int | list[int],
        left_chunk: int | list[int],
        num_lang: int | None = None,
        attention_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        input_layer: str = "nemo_conv",
        causal: bool = True,
        batch_norm: bool = False,
        cnn_out: int = -1,
        cnn_layer_norm: bool = False,
        ext_pw_out_channel: int = 0,
        ext_pw_kernel_size: int = 1,
        depthwise_seperable_out_channel: int = 256,
        depthwise_multiplier: int = 1,
        chunk_se: int = 0,
        kernel_size: int = 3,
        activation: str = "relu",
        conv_activation: str = "relu",
        conv_glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
        linear_glu_in_convm: bool = False,
        attention_glu_type: str = "swish",
        export: bool = False,
        extra_layer_output_idx: int = -1,
        extra_multi_layer_output_idxs: list[int] = [],  # noqa
        activation_checkpointing: str = "",
        relative_attention_bias_args: dict[str, Any] | None = None,
        time_reduction: int = 4,
        use_pt_scaled_dot_product_attention: bool = False,
        nemo_conv_settings: dict[str, Any] | None = None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        replication_pad_for_subsample_embedding: bool = False,
        attention_group_size: int = 1,
        encoder_embedding_config: dict[str, Any] | None = None,
    ) -> None:
        """Build the shared front-end (via the base class) and the stack of
        ``num_blocks`` ConformerEncoderLayer blocks. See the class docstring
        for the meaning of each argument."""
        super().__init__(
            input_size,
            chunk_size,
            left_chunk,
            attention_dim,
            attention_heads,
            input_layer,
            cnn_out,
            cnn_layer_norm,
            time_reduction,
            dropout_rate=dropout_rate,
            relative_attention_bias_args=relative_attention_bias_args,
            positional_dropout_rate=0.0,
            nemo_conv_settings=nemo_conv_settings,
            conv2d_extra_padding=conv2d_extra_padding,
            attention_group_size=attention_group_size,
            encoder_embedding_config=encoder_embedding_config,
        )
        self.num_blocks = num_blocks
        self.num_lang = num_lang
        self.kernel_size = kernel_size
        self.replication_pad_for_subsample_embedding: bool = (
            replication_pad_for_subsample_embedding
        )
        # self.num_heads is expected to be set by the base-class constructor.
        assert self.num_heads % attention_group_size == 0, (
            "attention_group_size must divide n_head"
        )
        self.num_heads_k = self.num_heads // attention_group_size

        self.encoders = MultiSequential(
            *[
                ConformerEncoderLayer(
                    d_model=attention_dim,
                    ext_pw_out_channel=ext_pw_out_channel,
                    depthwise_seperable_out_channel=depthwise_seperable_out_channel,
                    depthwise_multiplier=depthwise_multiplier,
                    n_head=attention_heads,
                    d_ffn=linear_units,
                    ext_pw_kernel_size=ext_pw_kernel_size,
                    kernel_size=kernel_size,
                    dropout_rate=dropout_rate,
                    causal=causal,
                    batch_norm=batch_norm,
                    activation=activation,
                    chunk_se=chunk_se,
                    chunk_size=chunk_size,
                    conv_activation=conv_activation,
                    conv_glu_type=conv_glu_type,
                    bias_in_glu=bias_in_glu,
                    linear_glu_in_convm=linear_glu_in_convm,
                    attention_glu_type=attention_glu_type,
                    activation_checkpointing=activation_checkpointing,
                    export=export,
                    use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
                    attn_group_sizes=attention_group_size,
                )
                for _ in range(num_blocks)
            ]
        )
        self.extra_layer_output_idx = extra_layer_output_idx
        # Copy so the instance never aliases the shared mutable default (or a
        # list the caller may mutate later).
        self.extra_multi_layer_output_idxs = list(extra_multi_layer_output_idxs)
        # Make a zeros scalar we can use in get_initial_state to determine
        # the device and the needed dtype:
        self.register_buffer("dev_type", torch.zeros(()), persistent=False)

    def init_relative_attention_bias(
        self, input_tensor: torch.Tensor
    ) -> torch.Tensor | None:
        """Return the relative attention bias for ``input_tensor``.

        Returns ``None`` when ``self.relative_attention_bias_layer`` is falsy
        (i.e. no bias layer is configured).
        """
        if self.relative_attention_bias_layer:
            return self.relative_attention_bias_layer(input_tensor)
        return None

    def calculate_hs_mask(
        self, xs_pad: torch.Tensor, device: torch.device, mask: torch.Tensor | None
    ) -> torch.Tensor:
        """Combine the streaming mask with a length-based padding mask.

        Args:
            xs_pad: torch.Tensor
                tensor whose dim 0 is the batch size and dim 1 the
                (maximum) sequence length.
            device: torch.device
                device on which the masks are built.
            mask: torch.Tensor or None
                per-sequence validity mask; it is summed over dim 1 to obtain
                the number of valid frames of each sequence. When None, only
                the streaming mask is returned.
        """
        max_audio_length = xs_pad.shape[1]
        batch_size = xs_pad.shape[0]
        enc_streaming_mask = self._streaming_mask(
            max_audio_length, batch_size, self.chunk_size, self.left_chunk
        )
        enc_streaming_mask = enc_streaming_mask.to(device)
        if mask is None:
            return enc_streaming_mask

        # Number of valid frames per sequence.
        feature_lens = mask.sum(1)
        # pad_mask[b, t] is True while t < feature_lens[b].
        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
            feature_lens.size(0), -1
        ) < feature_lens.unsqueeze(1)
        # Insert a singleton dim so it broadcasts against the streaming mask.
        pad_mask = pad_mask.unsqueeze(1)
        return pad_mask & enc_streaming_mask

    @torch.jit.ignore
    def forward(
        self, xs_pad: torch.Tensor, masks: torch.Tensor | None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Conformer Forward function

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor or None
                mask passed to ``forward_embeddings`` together with the
                embedded input (original note: post-embedding input lengths)

        Returns:
            The encoder output and the ``masks`` value returned by
            ``forward_embeddings``.
        """
        xs_pad = self.encoder_embedding(xs_pad)
        input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
            xs_pad, masks
        )

        unfolded = False
        chunk_pad_size = 0
        ori_bz, seq_len, _ = input_tensor.shape
        max_seq_len = 500  # maximum position for absolute positional encoding
        if seq_len > max_seq_len:
            # audio sequence is longer than max_seq_len, unfold it into chunks
            # of max_seq_len
            unfolded = True
            # the unfold op will drop residual frames, pad it to the multiple
            # of max_seq_len
            remainder = seq_len % max_seq_len
            if remainder > 0:
                chunk_pad_size = max_seq_len - remainder
                input_tensor = F.pad(
                    input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0
                )
            input_tensor = unfold_tensor(input_tensor, max_seq_len)
            if masks is not None:
                # revise hs_mask here because the previous calculated hs_mask
                # did not consider extra pad
                # [bz, subsampled_unmask_seq_len]
                subsampled_pad_mask = masks.squeeze(1)
                # extra padding to the pad mask
                extra_padded_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
                )
                # unfold op does not support bool tensor, so go through float
                extra_padded_pad_mask = extra_padded_pad_mask.unsqueeze(-1).float()
                # unfold the pad mask like we did to the input tensor
                masks_unfold = unfold_tensor(extra_padded_pad_mask, max_seq_len)
                masks_unfold = masks_unfold.squeeze(-1).bool()
            else:
                masks_unfold = None
            # calculate hs_mask based on the unfolded pad mask
            hs_mask = self.calculate_hs_mask(
                input_tensor, input_tensor.device, masks_unfold
            )

        relative_attention_bias = self.init_relative_attention_bias(input_tensor)

        _simplified_path = (
            self.extra_layer_output_idx == -1 and relative_attention_bias is None
        )

        if _simplified_path:
            input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
        else:
            for layer in self.encoders:
                input_tensor, _, _, _ = layer(
                    input_tensor,
                    pos_k,
                    pos_v,
                    hs_mask,
                    relative_attention_bias=relative_attention_bias,
                )

        if unfolded:
            embed_dim = input_tensor.shape[-1]
            input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
            # if we ever padded before unfolding, we need to remove the padding
            if chunk_pad_size > 0:
                input_tensor = input_tensor[:, :-chunk_pad_size, :]

        return input_tensor, masks


class WindowQformer(nn.Module):
    """Window-level Qformer"""

    def __init__(
        self,
        window_size: int = 8,
        num_queries: int = 1,
        num_blocks: int = 2,
        attention_dim: int = 512,
        attention_heads: int = 8,
        linear_units: int = 2048,
        dropout_rate: float = 0.0,
        normalize_before: bool = True,
    ):
        super().__init__()

1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
        self.decoders = nn.ModuleList(
            [
                nn.TransformerDecoderLayer(
                    d_model=attention_dim,
                    nhead=attention_heads,
                    dim_feedforward=linear_units,
                    dropout=dropout_rate,
                    activation="relu",
                    batch_first=True,
                    norm_first=normalize_before,  # TODO need to verify
                )
                for _ in range(num_blocks)
            ]
        )
1056
1057

        self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))
1058
1059
1060
        self.after_norm = (
            nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None
        )
1061
1062
        self.window_size = window_size

1063
    def forward(
1064
1065
        self,
        audio_embed: torch.Tensor,
1066
1067
1068
        mask: torch.Tensor | None,
        embed_len: int | None = None,
    ) -> tuple[torch.Tensor, int | None]:
1069
1070
1071
1072
1073
1074
1075
        """forward decoder"""
        # audio_embed: N x T x D => N x D x T

        audio_embed = audio_embed.transpose(1, 2)
        # audio_embed: N x D x 1 x T => N x DK x T'
        padding = audio_embed.shape[-1] % self.window_size
        if padding > 0:
1076
1077
1078
            audio_embed = F.pad(
                audio_embed, (0, self.window_size - padding), "constant", 0
            )
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094

        embed_chunk = F.unfold(
            audio_embed[..., None, :],
            kernel_size=(1, self.window_size),
            stride=(1, self.window_size),
        )
        bsz, _, slen = embed_chunk.shape
        # N x D x K x T'
        embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen)
        # N x T' x K x D
        embed_chunk = embed_chunk.transpose(1, 3).contiguous()
        # NT' x K x D
        embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1)
        # NT' x 1 x D
        q = self.queries.expand(bsz * slen, -1, -1)
        for layer in self.decoders:
1095
            q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask)
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110

        if self.after_norm is not None:
            q = self.after_norm(q)

        if embed_len is not None:
            embed_len = embed_len // self.window_size
        # N x T' x D
        out = q.view(bsz, slen, -1)

        return out, embed_len


class AudioEmbedding(nn.Module):
    """Audio embedding.

    Wraps a ConformerEncoder, an optional down-sampling stage (window
    Qformer or NeMo conv subsampling) and a projection into the text model's
    hidden size.

    Recognised ``kwargs`` (all optional): ``freeze_audio_processor``,
    ``downsample_rate``, ``use_qformer``, ``qformer_config``,
    ``use_conv_downsample``, ``nemo_conv_settings``, ``projection_cls``
    (``"linear"`` or ``"mlp"``).
    """

    def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
        super().__init__()
        self.config = config
        # n_embed or hidden_size for text LM
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        # Set according to the actual audio processor below.
        audio_dim_out = None
        self.layer_idx = -2

        if (
            isinstance(config.audio_processor, dict)
            and config.audio_processor.get("name", None) == "cascades"
        ):
            encoder_config = config.audio_processor.get("config", None)
            assert encoder_config is not None
            self.encoder = ConformerEncoder(**encoder_config)

            audio_dim_out = encoder_config["attention_dim"]
            n_mels = encoder_config["input_size"]
        else:
            raise NotImplementedError(
                "only the 'cascades' audio processor is supported"
            )

        assert audio_dim_out is not None, "Remember to set values for audio_dim_out"
        self.audio_dim_out = audio_dim_out
        self.audio_dim_in = n_mels

        self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False)

        self.downsample_rate = kwargs.get("downsample_rate", 1)

        if kwargs.get("use_qformer", False):
            # Copy so the caller's dict is not mutated by the override below.
            qformer_config = dict(kwargs.get("qformer_config", {}))
            qformer_config["attention_dim"] = audio_dim_out
            self.qformer = WindowQformer(**qformer_config)
        else:
            self.qformer = None

        if kwargs.get("use_conv_downsample", False):
            assert self.qformer is None, (
                "don't support use qformer and conv downsample together"
            )
            nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
            default_nemo_conv_settings = {
                "subsampling": "dw_striding",
                "subsampling_factor": self.downsample_rate,
                "feat_in": audio_dim_out,
                "feat_out": audio_dim_out,
                "conv_channels": 256,
                "subsampling_conv_chunking_factor": 1,
                "activation": nn.ReLU(),
                "is_causal": False,
            }
            # Override any of the defaults with the incoming, user settings
            if nemo_conv_settings:
                # These keys are derived from downsample_rate / the encoder
                # config and must not be overridden through the dictionary.
                for key in ["subsampling_factor", "feat_in", "feat_out"]:
                    assert key not in nemo_conv_settings, (
                        f"{key} should be specified outside of the NeMo dictionary"
                    )
                default_nemo_conv_settings.update(nemo_conv_settings)

            self.conv_ds = NemoConvSubsampling(
                **default_nemo_conv_settings,
            )
        else:
            self.conv_ds = None

        projection_cls = kwargs.get("projection_cls", "linear")
        if projection_cls == "linear":
            # The linear projection consumes un-stacked encoder frames, so no
            # frame stacking is applied in get_audio_features.
            self.linear_downsample_rate = 1
            self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
        elif projection_cls == "mlp":
            # follow llava-v1.5's implementation
            # (do not use image_projection and image_proj_norm)
            dim_projection = hidden_size
            depth = 2
            self.linear_downsample_rate = (
                1 if (self.qformer or self.conv_ds) else self.downsample_rate
            )

            def build_mlp() -> nn.Sequential:
                # Linear -> (GELU -> Linear) * (depth - 1)
                layers = [
                    nn.Linear(
                        audio_dim_out * self.linear_downsample_rate, dim_projection
                    )
                ]
                for _ in range(1, depth):
                    layers.extend(
                        [nn.GELU(), nn.Linear(dim_projection, dim_projection)]
                    )
                return nn.Sequential(*layers)

            self.audio_projection = build_mlp()
            # NOTE vision-speech tasks use a separate projection layer
            self.audio_projection_for_vision = build_mlp()
        else:
            raise NotImplementedError(
                f"projection_cls = {projection_cls}, not implemented"
            )

        # TODO: audio sequence compression - Qformer
        self.vocab_size = config.vocab_size
        self.input_embeds = None
        self.audio_embed_sizes = None

    def set_audio_embeds(self, input_embeds: torch.Tensor) -> None:
        """Store ``input_embeds`` on the module."""
        self.input_embeds = input_embeds

    def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None:
        """Store ``audio_embed_sizes`` on the module."""
        self.audio_embed_sizes = audio_embed_sizes

    def get_audio_features(
        self,
        input_embeds: torch.Tensor,
        audio_attention_mask: torch.Tensor | None = None,
        audio_projection_mode: str = "speech",
    ) -> torch.Tensor:
        """Encode, optionally down-sample, and project audio features.

        arguments:
            input_embeds: audio features (B, T, D)  B: num audios in a sequence
            audio_attention_mask: mask handed to the encoder, may be None
            audio_projection_mode: "speech" or "vision"; selects which
                projection is applied ("vision" uses
                ``audio_projection_for_vision``, which is only created when
                projection_cls == "mlp")

        raises:
            ValueError: for any other ``audio_projection_mode``
        """
        if self.freeze_audio_processor:
            with torch.no_grad():
                audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
        else:
            audio_features, masks = self.encoder(input_embeds, audio_attention_mask)

        if self.qformer is not None:
            audio_features, _ = self.qformer(audio_features, mask=None)

        if self.conv_ds is not None:
            if masks is not None:
                masks = masks.squeeze(1)

            audio_features, masks = self.conv_ds(audio_features, mask=masks)

        if self.linear_downsample_rate != 1:
            # Stack `linear_downsample_rate` consecutive frames along the
            # feature axis, zero-padding the time axis to a multiple first.
            bs, seq_len, feat_dim = audio_features.size()
            remainder = seq_len % self.linear_downsample_rate
            if remainder > 0:
                audio_features = F.pad(
                    audio_features,
                    (0, 0, 0, self.linear_downsample_rate - remainder),
                    "constant",
                    0,
                )

            seq_len = audio_features.size(1)
            audio_features = audio_features.view(
                bs,
                seq_len // self.linear_downsample_rate,
                feat_dim * self.linear_downsample_rate,
            )

        if audio_projection_mode == "speech":
            audio_set_tensor = self.audio_projection(audio_features)
        elif audio_projection_mode == "vision":
            audio_set_tensor = self.audio_projection_for_vision(audio_features)
        else:
            raise ValueError(
                f"audio_projection_mode = {audio_projection_mode} not implemented"
            )

        return audio_set_tensor

    def forward(
        self,
        audio_features: torch.Tensor,
        audio_attention_mask: torch.Tensor | None = None,
        audio_projection_mode: str = "speech",
    ) -> torch.Tensor:
        """
        arguments:
            audio_features: audio features (T, D)

        returns:
            audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
        """
        # Add a batch dim of 1 for the encoder, then drop it again.
        audio_embeds = self.get_audio_features(
            audio_features.unsqueeze(0),
            audio_attention_mask=audio_attention_mask,
            audio_projection_mode=audio_projection_mode,
        )
        return audio_embeds.squeeze(0)