aria.py 25.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union
4
5
6

import torch
import torch.nn as nn
7
8
9
from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
10
11
12
13
14
15
16
17

from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
18
19
from vllm.model_executor.layers.sampler import (SamplerOutput,
                                                SamplingMetadata, get_sampler)
20
21
22
23
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY
24
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
25
from vllm.multimodal.parse import MultiModalDataItems
26
from vllm.multimodal.processing import (BaseMultiModalProcessor,
27
28
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
29
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
30
31
from vllm.sequence import IntermediateTensors

32
# yapf: disable
33
from .idefics2_vision_model import Idefics2VisionConfig
34
35
36
from .idefics2_vision_model import (
    Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
37
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
38
39
40
41
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    is_pp_missing_parameter, maybe_prefix,
                    merge_multimodal_embeddings)
42
43
44
45
46
47


class AriaImagePixelInputs(TypedDict):
    """
    Image pixel data and the optional pixel mask for Aria.

    Shape:
        pixel_values: `(batch_size * num_images, num_channels, height, width)`
        pixel_mask: `(batch_size * num_images, height, width)`
    """
    pixel_values: torch.Tensor
    # `None` when no pixel mask was supplied alongside the pixel values.
    pixel_mask: Optional[torch.Tensor]


54
55
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
    """
    Vision tower for Aria.

    Reuses the Idefics vision transformer but replaces its final
    `post_layernorm` with an identity, and loads checkpoints whose
    q/k/v projections are stored separately into the fused `qkv_proj`.
    """
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config, prefix)
        # Unlike Idefics3VisionTransformer which uses LayerNorm after the
        # final layer, Aria omits this normalization, so we replace it with an
        # Identity layer
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """
        Load checkpoint weights into this module.

        Args:
            weights: Iterable of `(name, tensor)` pairs.

        Returns:
            The set of parameter names (after q/k/v -> qkv renaming) that
            were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:

            # NOTE: post_layernorm is not used in Aria
            if "post_layernorm" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Separate q/k/v shards are written into the fused parameter.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter: use the parameter's own loader if
                # it defines one, otherwise the default loader.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


102
class AriaProjectorMLP(nn.Module):
    """
    Two-layer bias-free feed-forward network used by the Aria projector.

    Applies `linear_in` (column-parallel), the `gelu_new` activation, then
    `linear_out` (row-parallel).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        output_dim: int,
    ) -> None:
        super().__init__()

        self.linear_in = ColumnParallelLinear(in_features,
                                              hidden_features,
                                              bias=False)
        self.linear_out = RowParallelLinear(hidden_features,
                                            output_dim,
                                            bias=False)
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` through the MLP and return the result."""
        # Both linear layers return a 2-tuple; only the first element is used.
        hidden_states, _ = self.linear_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_out(hidden_states)
        return hidden_states


class AriaProjector(nn.Module):
    """
    A projection module with one cross attention layer and one FFN layer, which
    projects ViT's outputs into MoE's inputs.

    Args:
        config (AriaConfig): Model configuration. The following fields are
            read:
            - `projector_patch_to_query_dict`: maps patch numbers to their
              corresponding query numbers, e.g., {1225: 128, 4900: 256}.
              This allows for different query sizes based on image
              resolution.
            - `max_value_projector_patch_to_query_dict`: number of rows
              allocated for the learned query parameter.
            - `vision_config.hidden_size`: input feature / key-value
              dimension.
            - `vision_config.num_attention_heads`: number of attention heads.
            - `text_config.hidden_size`: FFN hidden dimension and output
              dimension.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(self, config: AriaConfig) -> None:
        super().__init__()

        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = config.vision_config.hidden_size
        self.num_heads = config.vision_config.num_attention_heads
        self.kv_dim = config.vision_config.hidden_size
        self.hidden_features = config.text_config.hidden_size
        self.output_dim = config.text_config.hidden_size

        # Learned queries; `forward` uses only the first `query_num` rows.
        self.query = nn.Parameter(
            torch.empty(config.max_value_projector_patch_to_query_dict,
                        self.in_features))

        self.cross_attn = AriaCrossAttention(config)

        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(self.in_features,
                                             self.hidden_features,
                                             self.output_dim)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Project vision features into the language model's input space.

        Args:
            x: Vision features; dim 0 is the batch and dim 1 the number of
                patches.
            attn_mask: Optional per-patch mask with the batch on dim 0. It is
                repeated per attention head and expanded over the queries
                before being passed to the cross attention.

        Returns:
            The output of the feed-forward network applied to the
            layer-normalized cross attention output.

        Raises:
            KeyError: If the number of patches has no entry in
                `patch_to_query_dict`.
        """
        batch_size, num_patches = x.shape[0], x.shape[1]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(f"Number of patches {num_patches} not found in "
                           "patch_to_query_dict amongst possible values "
                           f"{self.patch_to_query_dict.keys()}.")

        query_num = self.patch_to_query_dict[num_patches]

        # Share the same learned queries across every item in the batch.
        queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)

        if attn_mask is not None:
            attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)

        out = self.feed_forward(self.layer_norm(attention_out))

        return out


class AriaFusedMoE(FusedMoE):
    """`FusedMoE` variant with a weight loader for Aria's packed experts."""

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      shard_id: str) -> None:
        """
        Copy a packed expert weight tensor into `param`.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor holding all experts.
            shard_id: `'w13'` for the merged up/gate weights or `'w2'` for
                the down weights. Any other value leaves `param` untouched.
        """
        # Override the weight_loader to handle the expert weights in the Aria
        # model, which are already packed with experts, and merge the gate and
        # up weights for each expert.
        # Note: Loading expert weights with quantization is not supported
        tp_rank = get_tensor_model_parallel_rank()
        if shard_id == 'w13':
            # the shape of loaded_weight is
            # (num_experts, hidden_size, 2 * moe_intermediate_size)
            if self.tp_size > 1:
                # Split the two halves first so each rank gets its slice of
                # both, then re-concatenate before transposing.
                up, gate = loaded_weight.chunk(2, dim=-1)
                up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank]
                gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank]
                up_and_gate = torch.cat([up_current_rank, gate_current_rank],
                                        dim=-1).transpose(1, 2)
                param.data.copy_(up_and_gate)
            else:
                param.data.copy_(loaded_weight.transpose(1, 2))
        elif shard_id == 'w2':
            # the shape of loaded_weight is
            # (num_experts, moe_intermediate_size, hidden_size)
            if self.tp_size > 1:
                down_current_rank = loaded_weight.chunk(self.tp_size,
                                                        dim=1)[tp_rank]
                param.data.copy_(down_current_rank.transpose(1, 2))
            else:
                param.data.copy_(loaded_weight.transpose(1, 2))


228
class AriaTextMoELayer(nn.Module):
    """
    Mixture of Experts (MoE) Layer for the AriaMoE model.

    This layer implements the MoE mechanism, which routes input tokens to
    different experts based on a routing algorithm, processes them through the
    experts, and then combines the outputs.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Router projection applied in `forward` via a plain linear op.
        self.router_weight = nn.Parameter(
            torch.empty(
                (self.config.moe_num_experts, self.config.hidden_size)))

        self.experts = AriaFusedMoE(
            num_experts=config.moe_num_experts,
            top_k=config.moe_topk,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            reduce_results=True,
            prefix=f"{prefix}.experts",
        )
        # The shared-expert MLP's intermediate size is scaled by the number
        # of shared experts.
        self.shared_experts = LlamaMLP(
            config.hidden_size,
            config.intermediate_size * config.moe_num_shared_experts,
            "silu",
            quant_config=quant_config,
            bias=config.mlp_bias,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MoE Layer.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape (batch_size,
            sequence_length, hidden_size).

        Returns:
            torch.Tensor: Output tensor after passing through the MoE layer.
        """

        router_output = torch.nn.functional.linear(hidden_states,
                                                   self.router_weight)

        # Clone first so the shared experts see the unmodified input.
        hidden_states_copy = hidden_states.clone()
        # NOTE: hidden_states will be modified inplace by `FusedMoE`
        sparse_expert_output = self.experts(hidden_states, router_output)
        shared_expert_output = self.shared_experts(hidden_states_copy)

        return sparse_expert_output + shared_expert_output


290
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Custom Decoder Layer for the AriaMoE model which modifies the standard
    `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of
    Experts (MoE) Layer.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, cache_config, quant_config, prefix)
        # Overwrite the MLP created by the parent constructor.
        self.mlp = AriaTextMoELayer(config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.mlp")
308
309


310
class AriaTextModel(LlamaModel, SupportsQuant):
    """
    Custom LlamaModel for the AriaMoE model which modifies the standard
    LlamaModel by replacing the `LlamaDecoderLayer` with
    `AriaTextDecoderLayer`.
    """
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts.w13_weight": ["experts.fc1.weight"],
        "experts.w2_weight": ["experts.fc2.weight"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=AriaTextDecoderLayer)

    # Adapted from LlamaModel.load_weights with the modification of adding
    # the expert weights mapping to `stacked_params_mapping`
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """
        Load checkpoint weights into the text model.

        Args:
            weights: Iterable of `(name, tensor)` pairs.

        Returns:
            The set of parameter names (after remapping) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            ("experts.w13_weight", "experts.fc1.weight", 'w13'),
            ("experts.w2_weight", "experts.fc2.weight", 'w2'),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Use the tensor itself if it is 0-dim, else its first
                # element.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


397
class AriaProcessingInfo(BaseProcessingInfo):
    """Accessors for Aria's HF config/processor and its image token budget."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as `AriaConfig`."""
        return self.ctx.get_hf_config(AriaConfig)

    def get_vision_config(self):
        """Return the `vision_config` sub-config of the HF config."""
        return self.get_hf_config().vision_config

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `AriaProcessor`, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(AriaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limits; `"image"` maps to `None`."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """
        Return the maximum number of tokens per multimodal item.

        `seq_len` and `mm_counts` are accepted but not used.
        """
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        """Return the largest value in `projector_patch_to_query_dict`."""
        hf_config = self.get_hf_config()
        return max(hf_config.projector_patch_to_query_dict.values())


class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
    """Builds dummy processor inputs for Aria."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Create dummy inputs with one image token per dummy image.

        Args:
            seq_len: Accepted but not used.
            mm_counts: Per-modality item counts; only `"image"` is read
                (defaults to 0 when absent).

        Returns:
            `ProcessorInputs` whose prompt is the image token repeated once
            per image and whose data holds square dummy images of side
            `vision_config.image_size`.
        """
        vision_config = self.info.get_vision_config()

        max_image_size = vision_config.image_size
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=max_image_size,
                                   height=max_image_size,
                                   num_images=num_images)
        }

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.tokenizer.image_token  # type: ignore

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=mm_data,
        )


451
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
    """Multimodal processor for Aria image inputs."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `pixel_mask` as batched image fields."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            pixel_mask=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """
        Replace each image token with `num_image_tokens` copies of itself.

        The count comes from `self.info.get_num_image_tokens()` and is the
        same for every image.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        num_image_tokens = self.info.get_num_image_tokens()

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=[image_token_id] * num_image_tokens,
            )
        ]
481
482


483
484
485
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor,
                                        info=AriaProcessingInfo,
                                        dummy_inputs=AriaDummyInputsBuilder)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """
    Aria model for conditional generation tasks.

    This model combines a vision tower, a multi-modal projector, and a language
    model to perform tasks that involve both image and text inputs.
    """
    # Renames applied to checkpoint weight names before loading.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "language_model",
            "language_model.lm_head": "lm_head",
        },
        orig_to_new_suffix={
            "router.weight": "router_weight",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vision_tower = AriaVisionTransformer(
            config.vision_config,
            quant_config,
            prefix=f"{prefix}.vision_tower",
        )
        self.multi_modal_projector = AriaProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AriaTextModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model.model"),
        )
        # Fall back to -1 when the config defines no pad token.
        self.pad_token_id = (self.config.pad_token_id
                             if self.config.pad_token_id is not None else -1)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config,
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.vocab_size, logit_scale)
        self.sampler = get_sampler()

    def _validate_image_sizes(
            self, images: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Check that every image has the same shape as the first one.

        Raises:
            ValueError: If any shape differs.
        """
        if not all(img.shape == images[0].shape for img in images):
            raise ValueError("All images must be the same size")
        return images

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
        """
        Extract `pixel_values` / `pixel_mask` from `kwargs` and validate them.

        Returns:
            `None` when no `pixel_values` were passed, otherwise the inputs
            flattened with `flatten_bn(..., concat=True)`.

        Raises:
            ValueError: If either input is neither a tensor nor a list, or
                if the images differ in shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_mask = kwargs.pop("pixel_mask", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        pixel_values = self._validate_image_sizes(pixel_values)
        pixel_values = flatten_bn(pixel_values, concat=True)

        if pixel_mask is not None:
            if not isinstance(pixel_mask, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel mask. "
                                 f"Got type: {type(pixel_mask)}")

            pixel_mask = flatten_bn(pixel_mask, concat=True)

        return AriaImagePixelInputs(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        )

    def _create_patch_attention_mask(
            self,
            pixel_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """
        Reduce a per-pixel mask to a per-patch boolean mask.

        Returns:
            `None` when `pixel_mask` is `None`; otherwise a boolean tensor
            that is True for every patch whose pixel-mask values sum to more
            than zero.
        """
        if pixel_mask is None:
            return None

        # Tile dims 1 and 2 into non-overlapping patch_size x patch_size
        # windows; the two trailing dims then index within a patch.
        patches_subgrid = pixel_mask.unfold(
            dimension=1,
            size=self.vision_tower.config.patch_size,
            step=self.vision_tower.config.patch_size,
        ).unfold(
            dimension=2,
            size=self.vision_tower.config.patch_size,
            step=self.vision_tower.config.patch_size,
        )
        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

    def _process_image_input(
            self, image_input: AriaImagePixelInputs) -> torch.Tensor:
        """
        Run the vision tower and projector on validated image inputs.

        Returns:
            The output of `multi_modal_projector`.
        """
        assert self.vision_tower is not None

        pixel_values = image_input['pixel_values']
        pixel_mask = image_input['pixel_mask']

        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

        image_outputs = self.vision_tower(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )
        image_attn_mask = None
        if patch_attention_mask is not None:
            # The projector receives the logical negation of the flattened
            # patch mask.
            flattened_mask = patch_attention_mask.flatten(1)
            image_attn_mask = torch.logical_not(flattened_mask)

        return self.multi_modal_projector(image_outputs, image_attn_mask)

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return image embeddings, or `None` if `kwargs` carry no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        multimodal_embeddings = self._process_image_input(image_input)
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """
        Embed `input_ids`, merging in `multimodal_embeddings` when provided.

        The merge is keyed on `config.image_token_index`.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        Run the language model on text plus any image inputs in `kwargs`.

        When `inputs_embeds` is not supplied it is computed here from
        `input_ids` and the image inputs, and `input_ids` is then dropped.
        """
        if inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            # always pass the input via `inputs_embeds`
            # to make sure the computation graph is consistent
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings)
            input_ids = None

        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from `hidden_states` using `lm_head`."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from `logits` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights via `AutoWeightsLoader`, applying `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)