granite_speech.py 33.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
omahs's avatar
omahs committed
25
"""Inference-only IBM Granite speech model."""
26

27
import math
28
from collections.abc import Iterable, Mapping
29
from typing import Annotated, Literal, cast
30

31
import numpy as np
32
33
34
35
36
import torch
import torch.nn.functional as F
from torch import nn
from transformers import BatchFeature, PretrainedConfig

37
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
38
from vllm.config.multimodal import BaseDummyOptions
39
from vllm.inputs.data import PromptType
40
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
41
42
43
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
60
61
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
62
63
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.processor import cached_processor_from_config
64
from vllm.utils.tensor_schema import TensorSchema, TensorShape
65
66

from .blip2 import Blip2QFormerModel
67
68
69
70
71
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
72
    SupportsTranscription,
73
)
74
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
75

76
77
78
79
80
81
82
83
84
85
86
87
# NOTE: Language support is based on what is written here:
# https://huggingface.co/ibm-granite/granite-speech-3.3-2b
# This may vary from model to model, and many other languages also
# work pretty well zero-shot.
# Keys are ISO 639-1 language codes; values are the English language names.
ISO639_1_SUPPORTED_LANGS = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "pt": "Portuguese",
    "es": "Spanish",
}

88
89

### Audio Input
90
91
92
class GraniteSpeechAudioInputs(TensorSchema):
    """
    Schema for the audio inputs consumed by the Granite Speech model.

    Dimensions:
        - b: Batch size
        - fi: Number of input features from the Mel spectrogram.
        - fo: Number of output features, i.e. the embedding size.
        - 160: Fixed feature dimension for Mel spectrogram features
    """

    input_features: Annotated[torch.Tensor, TensorShape("b", "fi", 160)]
    """Mel spectrogram input features for each audio in the batch."""

    input_features_mask: Annotated[torch.Tensor, TensorShape("b", "fo")]
    """Mask used to handle variable length audio features."""

    audio_embed_sizes: Annotated[list[int], TensorShape("b")]
    """Audio embedding size of every item in the batch."""
109
110
111


class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for Granite Speech: modality limits and the audio
    sizes used when building dummy inputs."""

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the maximum number of items allowed per modality."""
        limits: dict[str, int | None] = {"audio": 1}
        return limits

    # The number of audio tokens that can be encoded as features is not
    # bounded; ~5000 is picked as a value that is probably higher than
    # anything we would expect to encounter. A sequence of length
    # get_max_audio_len() produces get_max_audio_tokens().
    def get_max_audio_tokens(self):
        """Return the assumed upper bound on audio tokens."""
        max_audio_tokens = 5001
        return max_audio_tokens

    def get_max_audio_len(self):
        """Return the audio length matching get_max_audio_tokens()."""
        max_audio_len = 8_000_000
        return max_audio_len


### Input Processing  & Multimodal utils
class GraniteSpeechMultiModalProcessor(
    BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]
):
    """Multimodal processor that prepares audio inputs for Granite Speech and
    expands the audio placeholder token in the prompt."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Create a data parser targeting the feature extractor's sample rate."""
        audio_processor = self.info.get_hf_processor().audio_processor
        target_sr = audio_processor.melspec_kwargs["sample_rate"]
        return MultiModalDataParser(target_sr=target_sr)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare both audio fields as batched along the audio modality."""
        return {
            "input_features": MultiModalFieldConfig.batched("audio"),
            "audio_embed_sizes": MultiModalFieldConfig.batched("audio"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        """Replace each audio token with one audio token per projector
        feature computed for the corresponding audio item."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        audio_processor = hf_processor.audio_processor
        token_to_id = tokenizer.get_vocab()

        # getattr with a default keeps compatibility with transformers<4.48
        audio_token = getattr(hf_processor, "audio_token", "<|audio|>")
        audio_token_id = token_to_id[audio_token]

        def get_replacement(item_idx: int):
            audio_items = mm_items.get_items("audio", AudioProcessorItems)
            num_samples = audio_items.get(item_idx).shape[-1]
            feature_counts = audio_processor._get_num_audio_features(
                [num_samples]
            )
            return [audio_token_id] * feature_counts[0]

        audio_replacement = PromptReplacement(
            modality="audio",
            target=[audio_token_id],
            replacement=get_replacement,
        )
        return [audio_replacement]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and record the per-entry audio token counts."""
        hf_mm_data = dict(mm_data)
        audio_list = hf_mm_data.pop("audios", [])

        # GraniteSpeechFeatureExtractor accepts "audio" rather than "audios"
        if audio_list:
            hf_mm_data["audio"] = audio_list

        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=hf_mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        if "audio" not in hf_mm_data:
            return outputs

        # Count the audio tokens per entry in the batch; this is used to
        # split the batch back out after padding.
        audio_token_index = self.info.get_hf_config().audio_token_index
        is_audio_token = outputs["input_ids"] == audio_token_index
        outputs["audio_embed_sizes"] = is_audio_token.sum(-1)
        return outputs


class GraniteSpeechDummyInputsBuilder(
    BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]
):
    """Builds dummy audio data and the matching dummy prompt text."""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audios of the maximum audio length."""
        audio_count = mm_counts.get("audio", 0)
        if mm_options:
            audio_overrides = mm_options.get("audio")
        else:
            audio_overrides = None

        dummy_audios = self._get_dummy_audios(
            length=self.info.get_max_audio_len(),
            num_audios=audio_count,
            overrides=audio_overrides,
        )
        return {"audio": dummy_audios}

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the audio token repeated once per requested audio."""
        audio_count = mm_counts.get("audio", 0)
        processor = self.info.get_hf_processor()
        placeholder = getattr(processor, "audio_token", "<|audio|>")
        return placeholder * audio_count


### QFormer Projector
class GraniteSpeechEncoderProjector(nn.Module):
    """Projects encoder outputs into the text model's hidden size by running
    a QFormer over fixed-size windows followed by a linear layer."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        projector_config = config.projector_config

        self.hidden_size = projector_config.hidden_size
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = config.window_size // config.downsample_rate

        # One set of query embeddings, shared by every window.
        query_shape = (1, self.num_queries, projector_config.hidden_size)
        self.query = nn.Parameter(torch.zeros(*query_shape))

        # NOTE - transformers implements this generically, but for now the
        # QFormer model is created directly since all existing models use
        # it for the projector.
        self.qformer = Blip2QFormerModel(
            projector_config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.qformer",
        )
        self.linear = nn.Linear(
            projector_config.hidden_size, config.text_config.hidden_size
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pad the sequence to whole windows, run the QFormer on each window
        and project the result with the linear layer."""
        bsz, num_frames, feat_dim = hidden_states.size()
        num_windows = math.ceil(num_frames / self.window_size)
        padded_len = num_windows * self.window_size

        # Zero pad the end of the sequence dim up to a multiple of the window
        hidden_states = nn.functional.pad(
            hidden_states, (0, 0, 0, padded_len - num_frames), "constant", 0
        )
        windows = hidden_states.view(bsz * num_windows, self.window_size, feat_dim)

        qformer_out = self.qformer(
            query_embeds=self.query.data,
            encoder_hidden_states=windows,
        )

        # Fold the windows back into a single sequence per batch entry
        num_outputs = padded_len // self.downsample_rate
        return self.linear(qformer_out.view(bsz, num_outputs, -1))


# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
# NOTE - it would be nice to see if we can align this with other models using
# conformer in vLLM, e.g., phi4mm audio.
class GraniteSpeechConformerFeedForward(nn.Module):
    """Feedforward module for conformer encoder blocks: layer norm, up
    projection, SiLU, down projection."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        model_dim = config.hidden_dim
        ff_dim = config.hidden_dim * config.feedforward_mult

        self.pre_norm = nn.LayerNorm(model_dim)
        self.up_proj = ColumnParallelLinear(
            input_size=model_dim,
            output_size=ff_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.silu = nn.SiLU()
        self.down_proj = RowParallelLinear(
            input_size=ff_dim,
            output_size=model_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feedforward stack; the residual is added by the caller."""
        normed = self.pre_norm(hidden_states)
        # The parallel linear layers return a tuple; only the first entry
        # (the projected tensor) is used here.
        expanded, _ = self.up_proj(normed)
        activated = self.silu(expanded)
        projected, _ = self.down_proj(activated)
        return projected


class GraniteSpeechConformerAttention(nn.Module):
    """Attention for conformer blocks using Shaw's relative positional
    embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
    for more details.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()

        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = self.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
347
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)
348
349
350
351
352
353

        if self.context_size <= 0 or self.context_size > self.max_pos_emb:
            raise ValueError(
                "Context size is either less than 0 or exceeds the max_pos_emb"
            )

354
355
356
    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
357
358
359
360
361
362
363
364
        hidden_states = self.pre_norm(hidden_states)
        bsz, num_features, _ = hidden_states.shape

        num_blocks = math.ceil(num_features / self.context_size)
        remainder = num_features % self.context_size
        if remainder > 0:
            # right padding to reach block size
            hidden_states = torch.nn.functional.pad(
365
366
                hidden_states, (0, 0, 0, self.context_size - remainder)
            )
367
368
369
370
371
372

        # NOTE: would be nice to try to use qkvparallellinear
        # here for this block attention implementation if possible
        query_states = self.to_q(hidden_states)
        key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)

373
374
375
376
377
378
379
380
381
        query_states = query_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        key_states = key_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        value_states = value_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
382
383
384
385

        # shaw's relative positional embedding
        dist = attention_dists.to(hidden_states.device)
        rel_pos_emb = self.rel_pos_emb(dist)
386
387
388
389
390
        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
        pos_attn = (
            torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1)
            * self.scale
        )
391
392
393

        if remainder > 0:
            # masked attention in the extended block
394
395
396
397
398
399
            mask = torch.ones(
                self.context_size,
                self.context_size,
                dtype=bool,
                device=hidden_states.device,
            )
400
401
402
403
            mask[:remainder, :remainder] = 0
            mask_value = -torch.finfo(pos_attn.dtype).max
            pos_attn[:, -1, :].masked_fill_(mask, mask_value)

404
405
406
407
408
409
410
411
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            out = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=pos_attn,
                scale=self.scale,
            )
412
413
414
415
416
417
418
        out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
        return self.to_out(out[:, :num_features, :])


class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Wrapper for a padded 1D depthwise convolution.

    The convolution uses ``groups=chan_in`` and the input is padded so that
    the output has the same length as the input.
    """

    def __init__(self, chan_in: int, chan_out: int, kernel_size: int, prefix: str = ""):
        """Create the convolution and precompute the padding.

        Args:
            chan_in: Number of input channels (also the number of groups).
            chan_out: Number of output channels.
            kernel_size: Size of the convolution kernel.
            prefix: Module prefix (not used by this module).
        """
        super().__init__()
        # Padding for the 1D conv is symmetric or close (i.e., offset by one):
        # even kernel sizes pad one element less on the right.
        pad_left = kernel_size // 2
        pad_right = pad_left - (kernel_size + 1) % 2
        self.padding = (pad_left, pad_right)

        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, groups=chan_in, bias=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pad the last dimension and apply the convolution."""
        hidden_states = F.pad(hidden_states, self.padding)
        return self.conv(hidden_states)


class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer conv module: layer norm, an expanding 1x1 conv with GLU,
    a depthwise conv, batch norm with SiLU, and a 1x1 conv back to
    ``hidden_dim``.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        model_dim = config.hidden_dim
        inner_dim = config.hidden_dim * config.conv_expansion_factor

        self.norm = nn.LayerNorm(model_dim)
        # Twice the inner dim, since the GLU halves the channel dim again
        self.up_conv = nn.Conv1d(model_dim, inner_dim * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            inner_dim,
            inner_dim,
            kernel_size=config.conv_kernel_size,
            prefix=f"{prefix}.depth_conv",
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, model_dim, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the conv stack; dims 1 and 2 are swapped for the convs and
        swapped back before returning."""
        channels_first = self.norm(hidden_states).permute(0, 2, 1)
        gated = self.glu(self.up_conv(channels_first))
        convolved = self.depth_conv(gated)
        activated = self.silu(self.batch_norm(convolved))
        return self.down_conv(activated).permute(0, 2, 1)


class GraniteSpeechConformerBlock(nn.Module):
    """Conformer block: feedforward, attention, conv module and a second
    feedforward, each with a residual connection, then a final layer norm."""

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff1")
        self.attn = GraniteSpeechConformerAttention(config, prefix=f"{prefix}.attn")
        self.conv = GraniteSpeechConformerConvModule(config, prefix=f"{prefix}.conv")
        self.ff2 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff2")
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        """Run the block; both feedforward outputs are scaled by 0.5 before
        being added to the residual."""
        residual = hidden_states
        hidden_states = 0.5 * self.ff1(residual) + residual

        residual = hidden_states
        attn_out = self.attn(residual, attention_dists=attention_dists)
        hidden_states = attn_out + residual

        residual = hidden_states
        hidden_states = self.conv(residual) + residual

        residual = hidden_states
        hidden_states = 0.5 * self.ff2(residual) + residual

        return self.post_norm(hidden_states)


class GraniteSpeechCTCEncoder(nn.Module):
    """CTC Encoder comprising conformer blocks and additional linear layers."""

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        self.config = config

        # Precompute the clamped relative position distances, shifted by
        # max_pos_emb before being handed to the attention layers.
        positions = torch.arange(config.context_size)
        relative = positions.view(-1, 1) - positions.view(1, -1)
        clamped = torch.clamp(relative, -config.context_size, config.context_size)
        self.attention_dists = clamped + config.max_pos_emb

        self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True)

        blocks = [
            GraniteSpeechConformerBlock(
                config,
                prefix=f"{prefix}.layers.{layer_idx}",
            )
            for layer_idx in range(config.num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        self.out = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.output_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out",
        )
        self.out_mid = RowParallelLinear(
            input_size=config.output_dim,
            output_size=config.hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_mid",
        )
        self.softmax = nn.Softmax(dim=-1)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor):
        """Run all conformer blocks; after the middle block, the softmax of
        the `out` projection is projected back through `out_mid` and added
        to the hidden states."""
        midpoint = self.num_layers // 2
        hidden_states = self.input_linear(hidden_states)

        for layer_num, block in enumerate(self.layers, start=1):
            hidden_states = block(hidden_states, attention_dists=self.attention_dists)
            if layer_num != midpoint:
                continue

            mid_states = hidden_states.clone()
            mid_states, _ = self.out(mid_states)
            mid_states = self.softmax(mid_states)
            mid_states, _ = self.out_mid(mid_states)
            hidden_states += mid_states

        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    GraniteSpeechMultiModalProcessor,
    info=GraniteSpeechMultiModalProcessingInfo,
558
559
    dummy_inputs=GraniteSpeechDummyInputsBuilder,
)
560
class GraniteSpeechForConditionalGeneration(
561
562
563
564
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    SupportsLoRA,
565
    SupportsTranscription,
566
):
567
    supported_languages = ISO639_1_SUPPORTED_LANGS
568
569
570
571
572
573
574
575
576
577
578
579
580

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

581
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for the given modality.

        Raises:
            ValueError: If the modality does not start with "audio".
        """
        if not modality.startswith("audio"):
            raise ValueError("Only audio modality is supported")

        return "<|audio|>"

588
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the language model, the conformer encoder and the projector
        from the given vLLM config."""
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.cache_config = vllm_config.cache_config

        # The language model is typically a Granite LLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=hf_config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Conformer encoder
        self.encoder = GraniteSpeechCTCEncoder(
            config=hf_config.encoder_config,
            quant_config=self.quant_config,
            prefix=f"{prefix}.encoder",
        )

        # Blip2 QFormer
        self.projector = GraniteSpeechEncoderProjector(
            config=hf_config,
            quant_config=self.quant_config,
            cache_config=self.cache_config,
            prefix=f"{prefix}.projector",
        )

        # Reuse the language model's factory for empty intermediate tensors
        language_model = self.language_model
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors
        )
623
624
625
626

    def _parse_and_validate_audio_input(
        self,
        **kwargs: object,
    ) -> GraniteSpeechAudioInputs | None:
        """Extract the audio inputs from the keyword arguments and validate
        them.

        Args:
            **kwargs: Model keyword arguments; ``input_features``,
                ``input_features_mask`` and ``audio_embed_sizes`` are read.
        Returns:
            GraniteSpeechAudioInputs | None: The validated audio inputs, or
            None when no ``input_features`` were passed.
        Raises:
            ValueError: If any of the audio inputs has an unexpected type, if
                ``audio_embed_sizes`` is missing, or if the input features
                are not 3D after squeezing.
        """
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)
        audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)

        if input_features is None:
            return None

        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of audio input features. "
                f"Got type: {type(input_features)}"
            )

        # The embed sizes are needed both to rebuild the mask and to split
        # the embeddings later on, so fail early with a clear message
        # instead of crashing inside torch when they are missing.
        if not isinstance(audio_embed_sizes, torch.Tensor):
            raise ValueError(
                "Incorrect type of audio embed sizes. "
                f"Got type: {type(audio_embed_sizes)}"
            )

        # If we have a batch of variable feature length audio clips, we need
        # to mask the features; usually we would get an input_features_mask
        # from the processor, but we handle rebuilding it here since
        # vLLM generally processes everything independently + batches.
        if input_features_mask is None:
            input_features_mask = self._build_input_features_mask(audio_embed_sizes)
        elif not isinstance(input_features_mask, torch.Tensor):
            raise ValueError(
                "Incorrect type of audio input features mask. "
                f"Got type: {type(input_features_mask)}"
            )

        if isinstance(input_features, torch.Tensor):
            # Granite speech currently only allows one audio token per instance
            # and features are already unsqueezed in the processor, so one
            # instance will have shape [1, {num_features}, 160]. As such,
            # input features will usually be of shape
            # [bsz, 1, num_features, 160], which we squeeze to be 3D here.
            if len(input_features.shape) == 4:
                input_features = input_features.squeeze(1)
            if len(input_features.shape) != 3:
                raise ValueError(
                    "Squeezed input features should be 3D but are of shape "
                    f"{input_features.shape}"
                )
        else:
            # Otherwise we have a list of tensors, which are almost certainly
            # differing in their respective numbers of audio features;
            # stack them into a 3D tensor of size [bsz, most_num_features, 160].
            input_features = self._pad_and_stack_input_features(input_features)

        # Match the dtype of the encoder's first layer
        input_features = input_features.to(self.encoder.input_linear.weight.dtype)

        return GraniteSpeechAudioInputs(
            input_features=input_features,
            input_features_mask=input_features_mask,
            audio_embed_sizes=audio_embed_sizes.flatten().tolist(),
        )

    def _build_input_features_mask(
        self,
        audio_embed_sizes: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the input features mask, which will generally be used
omahs's avatar
omahs committed
690
        to mask the padded features for all entries in the batch except
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
        for those with the most audio features.

        Args:
            audio_embed_sizes: torch.Tensor
                Tensor of num features in each seq in the batch.
        Returns:
            torch.Tensor: Mask of shape (bsz, num_features) to be applied to
            the audio features prior to splitting the audio embeddings.
        """
        most_audio_features = torch.max(audio_embed_sizes).item()
        mask_indices = torch.arange(
            most_audio_features,
            device=audio_embed_sizes.device,
        ).view(1, -1)
        input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
        return input_features_mask

    def _pad_and_stack_input_features(
        self,
        input_features: list[torch.Tensor],
    ) -> torch.Tensor:
        """Given a list of input features of varying length, pad them to the
        same length and stack them into a torch.Tensor.

        NOTE: Usually, padding is done in the input processor/feature extractor
        and zero padded prior to the computation of the Mel features; the
        resulting values are only constant within a batch and generally nonzero
        (i.e., slightly negative nums); we should validate that this is okay
        since we don't use a feature attention mask, but the more important
        thing is that we apply the input_features_mask with variable len
        batches.

        Args:
            input_features: list[torch.Tensor]
                Input features to be coerced into a tensor.
        Returns:
            torch.Tensor: Tensor of shape [bsz, num_features, 160], where
            num_features is the max number of features of any entry in the
            batch.
        """
        # Input features are of shape [bsz, num_features, 160]
        feat_lens = [feats.shape[1] for feats in input_features]
        padding = [max(feat_lens) - length for length in feat_lens]
        # TODO (Alex) - Validate that it's okay to zero pad like this;
        # in transformers we zero pad prior to calculating the speech features,
        # so the value is not zero and is dependent on the batched features.
        padded = [
            torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0))
            for feats, pad in zip(input_features, padding)
        ]
        stacked_features = torch.cat(padded, dim=0).to(input_features[0])
        return stacked_features

    def _process_audio_input(
        self,
        audio_input: GraniteSpeechAudioInputs,
    ) -> tuple[torch.Tensor]:
        """Encode and project the audio features that get merged into the
        LLM embeddings.

        Args:
            audio_input: GraniteSpeechAudioInputs
                Audio inputs object containing Mel features, an input features
                mask, and the (flattened) number of audio tokens per instance.
        Returns:
            tuple[torch.Tensor]: List of length bsz.
        """
        # TODO (Alex) - support embedding inputs
        encoded = self.encoder(audio_input["input_features"])
        # [bsz, <max feature size>, <projector output size>]
        projected = self.projector(encoded)
        # Keep only the unmasked (non-padded) variable length audio features
        valid_embeds = projected[audio_input["input_features_mask"]]
        # Split the flattened features back into one tensor per instance
        return torch.split(valid_embeds, audio_input["audio_embed_sizes"])

766
767
768
    def get_language_model(self) -> torch.nn.Module:
        """Return the `language_model` submodule held by this model."""
        return self.language_model

    def embed_multimodal(
        self,
        **kwargs: object,
    ) -> MultiModalEmbeddings:
        """Compute the audio embeddings if audio inputs are present.

        Args:
            **kwargs: Keyword arguments forwarded unchanged to
                `_parse_and_validate_audio_input`.
        Returns:
            An empty list when no audio input is parsed from `kwargs`,
            otherwise the result of `_process_audio_input`.
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            # No audio in this request: nothing to merge.
            return []

        return self._process_audio_input(audio_input)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Embed token IDs, delegating to the parent implementation.

        Args:
            input_ids: Token IDs to embed.
            multimodal_embeddings: Multimodal embeddings to pass through to
                the parent implementation, if any.
            is_multimodal: Tensor passed through to the parent
                implementation together with `multimodal_embeddings`
                (presumably a mask of multimodal token positions - confirm
                against the parent class).
            handle_oov_mm_token: Forwarded to the parent implementation;
                defaults to True here because the multi-modal token ID may
                exceed the vocab size.
        Returns:
            torch.Tensor: Whatever the parent `embed_input_ids` returns.
        """
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model on the given inputs.

        Args:
            input_ids: Token IDs passed to the language model.
            positions: Position tensor passed to the language model.
            intermediate_tensors: Optional intermediate tensors; when given,
                `inputs_embeds` is discarded before the call.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted for interface compatibility; not used here.
        Returns:
            The output of `self.language_model(...)`.
        """
        # When intermediate tensors are supplied the embeddings are dropped
        # (presumably because this rank is not the first pipeline stage -
        # confirm against the pipeline-parallel runner).
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
820
    ) -> torch.Tensor | None:
821
        return self.language_model.compute_logits(hidden_states)
822
823
824

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights into this model via `AutoWeightsLoader`.

        Args:
            weights: Iterable of `(name, tensor)` pairs to load.
        Returns:
            set[str]: The value returned by `AutoWeightsLoader.load_weights`
            (presumably the names of the parameters that were loaded).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models.

        Returns:
            MultiModelKeys: Keys naming `language_model` as the language
            model, `projector` as the connector, and `encoder` as the tower
            model.
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="projector",
            tower_model="encoder",
        )

    ### Support for speech-to-text Transcription
    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the generation prompt to be used for transcription requests.

        Args:
            audio: Audio data attached to the prompt as multimodal data.
            model_config: Model config used to look up the cached tokenizer.
            stt_config: Not used by this implementation.
            language: Not used by this implementation.
            task_type: Either "transcribe" or "translate"; selects the user
                instruction placed after the audio placeholder.
            request_prompt: Not used by this implementation.
            to_language: Target language for "translate"; looked up in
                `cls.supported_languages`, falling back to the raw value.
        Returns:
            PromptType: A dict with `prompt_token_ids` (the tokenized chat
            prompt) and `multi_modal_data` (the audio).
        Raises:
            ValueError: If `task_type` is neither "transcribe" nor
                "translate".
        """
        # Audio placeholders don't use an index, so value doesn't matter
        audio_tok = cls.get_placeholder_str("audio", 0)

        if task_type == "translate":
            # NOTE(review): if to_language is None the instruction ends with
            # the literal text "None" - confirm callers always provide it.
            full_lang_name_to = cls.supported_languages.get(to_language, to_language)
            user_prompt = f"{audio_tok}translate the speech to {full_lang_name_to}"  # noqa: E501
        elif task_type == "transcribe":
            user_prompt = (
                f"{audio_tok}can you transcribe the speech into a written format?"  # noqa: E501
            )
        else:
            raise ValueError(f"Unsupported task type {task_type}")

        tokenizer = cached_tokenizer_from_config(model_config)
        chat = [dict(role="user", content=user_prompt)]
        # Render the chat template to text first, then tokenize it, so the
        # engine receives token IDs rather than a raw string.
        prompt_text = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )

        prompt_token_ids = tokenizer.encode(prompt_text)
        prompt = {
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": {"audio": audio},
        }
        return cast(PromptType, prompt)

    # Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122 # noqa: E501
    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Get the number of audio tokens for an audio duration in sec.

        Args:
            audio_duration_s: Duration of the audio clip in seconds.
            stt_config: Speech-to-text config supplying `sample_rate`.
            model_config: Model config used to look up the cached processor,
                whose audio processor supplies the hop length, projector
                window size and projector downsample rate.
        Returns:
            int: Number of projector blocks multiplied by the effective
            window size (window size // downsample rate).
        """
        processor = cached_processor_from_config(model_config)
        audio_processor = processor.audio_processor
        hop_length = audio_processor.melspec_kwargs["hop_length"]
        proj_win_size = audio_processor.projector_window_size
        ds_rate = audio_processor.projector_downsample_rate
        effective_window_size = proj_win_size // ds_rate

        raw_length = audio_duration_s * stt_config.sample_rate

        # mel sequence length computation
        mel_length = raw_length // hop_length + 1
        # encoder frame takes two mel features
        encoder_length = mel_length // 2
        # math.ceil yields an int, so the product below is an int even
        # though the intermediate lengths derive from a float duration.
        nblocks = math.ceil(encoder_length / proj_win_size)
        # projector output length
        return nblocks * effective_window_size

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Get the stt config for this model.

        Args:
            model_config: Not used by this implementation.
            task_type: Not used by this implementation.
        Returns:
            SpeechToTextConfig: A config built with all default values.
        """
        # Default settings are reasonable for this model and we don't currently
        # expose this information in the model configs, but this may change in
        # the future
        return SpeechToTextConfig()