# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
18
from collections.abc import Iterable, Mapping
19
from contextlib import contextmanager
20
from typing import Literal, Optional, Union
21

22
import regex as re
23
24
import torch
from torch import nn
25
26
from transformers import (AutoModel, BatchFeature, PretrainedConfig,
                          PreTrainedModel)
27
28
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
32
33
34
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
35
36
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
37
                                               ReplicatedLinear,
38
39
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
42
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
45
46
47
48
49
50
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
51
from vllm.sequence import IntermediateTensors
52
from vllm.utils import is_list_of
53

54
55
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                         SupportsQuant)
56
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
57
58
                    flatten_bn, make_empty_intermediate_tensors_factory,
                    maybe_prefix)
59
60
61
62
63
64
65
66
67
68
69
70

# Module-level logger obtained from vLLM's `init_logger`, keyed by this
# module's name.
logger = init_logger(__name__)


def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attention_instances: Optional[dict[int, Attention]] = None,
        **kwargs):
    """Attention function that forwards to a per-layer vLLM `Attention`.

    Args:
        module: The calling attention module. Its `layer_idx` attribute
            selects the entry of `attention_instances` to use.
        query: Query states.
        key: Key states.
        value: Value states.
            NOTE(review): query/key/value are presumably laid out as
            (batch, heads, tokens, head_dim) with batch == 1 - confirm
            against the Transformers attention interface.
        attention_mask: Accepted for interface compatibility; not read here.
        scaling: When not `None`, overwrites `impl.scale` of the selected
            `Attention` instance before it is called.
        attention_instances: Mapping from layer index to `Attention`.
            Although it defaults to `None`, it must be provided: the
            function indexes it unconditionally.
        **kwargs: Any further keyword arguments are ignored.

    Returns:
        A 2-tuple `(output, None)` where `output` is whatever the selected
        `Attention.forward` returns. The second slot is always `None`
        (presumably the attention-weights position - confirm).
    """
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)

    # Second-to-last dim of the incoming query; used as the leading dim of
    # the flattened 2D tensors handed to vLLM.
    num_tokens = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1)
                         for x in (query, key, value))
    return self_attn.forward(query, key, value), None
82
83
84
85
86


# Register the wrapper with Transformers under the "vllm" key.
# `TransformersBase.__init__` selects it by setting
# `_attn_implementation = "vllm"` on the text config.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


87
88
89
90
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug log record describing a module substitution.

    Args:
        name: Name under which the module is reported.
        old_module: The module that was replaced.
        new_module: The module that took its place.
    """
    record_args = (name, old_module, new_module)
    logger.debug("%s: %s -> %s", *record_args)


91
def replace_linear_class(
    linear: nn.Linear,
    style: Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep",
                   "replicate"],
    quant_config: QuantizationConfig,
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced. Only its
            `in_features`, `out_features` and the presence of a bias are
            read; its weights are not copied.
        style (str): Tensor parallel style of the new linear, e.g. "colwise".
            Recognised values are "colwise", "colwise_rep", "rowwise",
            "rowwise_rep" and "replicate". Any other string falls back to
            `ReplicatedLinear`.
        quant_config (QuantizationConfig): Quantization config for the new
            linear.
    Returns:
        Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
            The new linear, constructed with `return_bias=False`.
    Raises:
        ValueError: If `style` is not a `str`.
    """

    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    # Style name -> (class, extra constructor kwargs). Unknown styles
    # deliberately fall back to a replicated layer.
    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {
            "gather_output": True
        }),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {
            "input_is_parallel": False
        }),
        "replicate": (ReplicatedLinear, {}),
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        return_bias=False,
        **vllm_linear_kwargs,
    )

131

132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: Union[torch.device, str]):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    While the context is active, `nn.Module.register_parameter` is replaced
    by a wrapper that re-creates every registered (non-`None`) parameter on
    `device`. The original method is restored on exit, including when the
    body raises. Tensor constructors are left untouched, which is why
    buffers are not affected.

    Args:
        device (`torch.device` or `str`):
            Device to initialize all parameters on, e.g. `"meta"`.
    """
    original_register_parameter = nn.Module.register_parameter

    def register_parameter_on_device(module, name, param):
        original_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        # Copy the attribute dict instead of mutating the parameter's own
        # `__dict__`; the copy is forwarded to the parameter constructor.
        kwargs = dict(registered.__dict__)
        kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = type(registered)(registered.to(device),
                                                    **kwargs)

    nn.Module.register_parameter = register_parameter_on_device
    try:
        yield
    finally:
        nn.Module.register_parameter = original_register_parameter


class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for image inputs, built on `BaseProcessingInfo`."""

    def get_hf_config(self):
        """Return the `hf_config` of the context's model config."""
        model_config = self.ctx.model_config
        return model_config.hf_config

    def get_supported_mm_limits(self):
        """Return the modality limits: only "image", mapped to `None`."""
        limits = {"image": None}
        return limits

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        """Return the per-item token maximum for the "image" modality.

        `seq_len` and `mm_counts` are accepted but not read.
        """
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens}

    def get_max_image_tokens(self) -> int:
        """Ask the HF processor how many tokens the largest image needs."""
        max_width, max_height = self.get_max_image_size()
        hf_processor = self.get_hf_processor()
        extra_kwargs = self.ctx.model_config.mm_processor_kwargs or {}
        token_info = hf_processor._get_num_multimodal_tokens(
            image_sizes=([max_height, max_width], ), **extra_kwargs)
        # One image size was passed in, so read the first entry.
        return token_info["num_image_tokens"][0]

    def get_max_image_size(self):
        """Return the two-element size that `get_max_image_tokens` unpacks
        as (width, height)."""
        # hardcode for arbitrary very large size
        side = 10_000
        return side, side


class MultiModalDummyInputsBuilder(
        BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy text and image inputs from the requested image count."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image placeholder token repeated once per image."""
        image_count = mm_counts.get("image", 0)

        hf_processor = self.info.get_hf_processor()
        processor_name = hf_processor.__class__.__name__.lower()
        # Processors whose class name contains "gemma3" use `boi_token`;
        # all others use `image_token`, or "" when that attribute is absent.
        placeholder = (hf_processor.boi_token if "gemma3" in processor_name
                       else getattr(hf_processor, "image_token", ""))
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images of the maximum image size under "image".

        `seq_len` is accepted but not read.
        """
        image_count = mm_counts.get("image", 0)
        max_width, max_height = self.info.get_max_image_size()
        dummy_images = self._get_dummy_images(width=max_width,
                                              height=max_height,
                                              num_images=image_count)
        return {"image": dummy_images}


class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Multi-modal processor that delegates prompt expansion to the HF
    processor and derives placeholder ranges from the token-type ids it
    returns."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct
        :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.

        This implementation always returns `None`; `apply` builds the
        placeholder ranges itself.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs,
        num_image_patches: Optional[torch.Tensor] = None,
    ):
        """Build the field config for every key of `hf_inputs`.

        NOTE: removes "attention_mask" from `hf_inputs` in place.

        Args:
            hf_inputs: Mapping of HF processor outputs.
            hf_processor_mm_kwargs: Accepted but not read.
            num_image_patches: Sizes passed to `flat_from_sizes` for each
                field.

        Returns:
            A dict mapping each remaining key of `hf_inputs`, plus
            "image_embeds", to a flat "image" field sized by
            `num_image_patches`; "num_image_patches" is a batched "image"
            field.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        mm_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image",
                                                       num_image_patches)
            for key in hf_inputs
        }
        mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
            "image", num_image_patches)
        mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
        return mm_fields

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, torch.Tensor]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            A 3-tuple of the prompt token ids, the remaining processed data
            (with "input_ids" and the token-type entry removed), and the
            token-type ids returned by the HF processor. The token-type ids
            are taken from "mm_token_type_ids" when present, otherwise from
            "token_type_ids".
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
        processor_data["return_mm_token_type_ids"] = True

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        processed_data.update(passthrough_data)

        # The processed batch holds exactly one prompt; unpack it.
        (prompt_ids, ) = processed_data.pop("input_ids").tolist()

        if "mm_token_type_ids" in processed_data:
            token_type_key = "mm_token_type_ids"
        else:
            token_type_key = "token_type_ids"  # for gemma3 only
        mm_token_type_ids = processed_data.pop(token_type_key)

        return prompt_ids, processed_data, mm_token_type_ids

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.

        Args:
            prompt: Prompt text, or token ids which are decoded back to
                text before processing.
            mm_data: Raw multi-modal data.
            hf_processor_mm_kwargs: Keyword arguments for the HF processor.
            tokenization_kwargs: Keyword arguments for tokenization;
                `None` is treated as an empty mapping.

        Returns:
            The `MultiModalInputs` with prompt, token ids, processed
            kwargs, hashes and image placeholder ranges.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # the prompt is the tokenized ids which is not supported
            # by the hf_processor, which is why we would need to decode the ids
            # into string
            prompt = hf_processor.decode(prompt)

        (prompt_ids, processed_data,
         mm_token_type_ids) = self._apply_hf_processor_text_mm(
             prompt_text=prompt,
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

        # HF processor will return `mm_token_type_ids` from which
        # we can infer mm_placeholders. Until then hardcode to make code run
        # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
        mm_positions = torch.where(mm_token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)
        mm_processor_kwargs = (self.info.ctx.model_config.mm_processor_kwargs
                               or {})
        image_sizes = [(size.height, size.width)
                       for size in map(images.get_image_size,
                                       range(len(images)))]

        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs)

        mm_placeholders = {}
        split_sizes = mm_tokens_per_modality["num_image_tokens"]
        if split_sizes:
            # Split the flat multi-modal positions/tokens into one chunk
            # per image, using the per-image token counts.
            chunked_mm_positions = torch.split(mm_positions, split_sizes)
            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
            chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
            ranges = [
                PlaceholderRange(
                    offset=chunk_positions[0].item(),
                    length=chunk_positions.shape[0],
                    is_embed=(
                        chunk_tokens == hf_processor.image_token_id).bool())
                for chunk_positions, chunk_tokens in zip(
                    chunked_mm_positions, chunked_mm_tokens)
            ]
            mm_placeholders = {"image": ranges}

        num_image_patches = torch.tensor(
            mm_tokens_per_modality["num_image_patches"]
        ) if "num_image_patches" in mm_tokens_per_modality else None
        processed_data["num_image_patches"] = num_image_patches
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                       num_image_patches),
        )

        mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
                                        tokenization_kwargs)

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


392
393
394
395
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Base wrapper that runs a Transformers `AutoModel` inside vLLM.

    The wrapped model is created on the "meta" device, has its pipeline and
    tensor parallel plans applied, gets vLLM attention instances attached,
    and only then has its remaining parameters allocated on the target
    device.
    """
    embedding_padding_modules = ["lm_head"]
    # TODO transformers will have a util to get it
    embedding_modules = ["embed_tokens"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the wrapped model from `vllm_config`.

        Args:
            vllm_config: Source of the model, cache, device, parallel and
                quantization configs.
            prefix: Not read by this base class.
        """
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: QuantizationConfig = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # To be updated in child classes for use in `load_weights`
        self.skip_prefixes: Optional[list[str]] = None

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        # TODO: @raushan, use the public `model.set_attn_implementation()`
        # method once its checks are fixed in Transformers.
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        self.pipeline_parallel()
        self.tensor_parallel()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    self.text_config.hidden_size,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                ))

        # Attention layers
        self.attention_instances = self.create_attention_instances()

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states"], self.text_config.hidden_size))

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Does nothing when `pp_size <= 1`. Otherwise modules named in the
        model's `_pp_plan` that do not belong on this rank are replaced
        with `PPMissingLayer`.

        Raises:
            ValueError: If the model does not support a pipeline plan, the
                plan names more than one `nn.ModuleList`, or names none.
        """
        if self.pp_size <= 1:
            return

        if not self.model.supports_pp_plan:
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel yet!")

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!")
        if module_list_idx is None:
            raise ValueError(
                f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            # Kept on the first rank, and also on the last rank when the
            # word embeddings are tied.
            if self.pp_group.is_first_rank or (
                    self.text_config.tie_word_embeddings
                    and self.pp_group.is_last_rank):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size)
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if start_layer <= i < end_layer:
                continue
            layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1:]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def tensor_parallel(self):
        """
        Apply the model's tensor parallelization plan.
        Currently only supports linear layers.

        Every `nn.Linear` found under `self.model` is replaced, including
        when `tp_size == 1`; linears not matched by a plan pattern use the
        "replicate" style.

        Raises:
            ValueError: If `tp_size > 1` and no `PreTrainedModel` inside
                `self.model` declares a `base_model_tp_plan`.
        """
        # Look for tp plans in all of the PreTrainedModels found in self.model.
        # The traversal is only needed when tensor parallelism is requested.
        if self.tp_size > 1:
            has_tp_plan = any(
                module.config.base_model_tp_plan is not None
                for module in self.model.modules()
                if isinstance(module, PreTrainedModel))
            if not has_tp_plan:
                raise ValueError(
                    f"{type(self.model)} does not support tensor parallel yet!"
                )

        def _tensor_parallel(module: nn.Module,
                             prefix: str = "",
                             tp_plan=None):
            tp_plan = tp_plan or {}

            # If the current module is a PreTrainedModel, set the tp_plan for
            # all of its children
            if isinstance(module, PreTrainedModel):
                tp_plan = module.config.base_model_tp_plan or {}
                tp_plan = {
                    maybe_prefix(prefix, k): v
                    for k, v in tp_plan.items()
                }

            # Some weight loaders expect linear layers to inherit from vLLM's
            # LinearBase class, so we set a default style which causes any
            # unspecified linear layers to be replaced with ReplicatedLinear
            for child_name, child_module in module.named_children():
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    # First plan pattern matching the qualified name wins.
                    pattern = next(
                        (p for p in tp_plan if re.match(p, qual_name)), None)
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(child_module, style,
                                                      self.quant_config)
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)
                else:
                    _tensor_parallel(child_module,
                                     prefix=qual_name,
                                     tp_plan=tp_plan)

        _tensor_parallel(self.model)

    def create_attention_instances(self) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        Returns:
            A dict mapping each layer index assigned to this pipeline rank
            to its `Attention` instance.
        """
        num_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.text_config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if (hasattr(self.config, "layer_types")
                    and self.config.layer_types[i] == "sliding_attention"):
                per_layer_sliding_window = self.config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn")
        return attention_instances

    def init_parameters(self, module: nn.Module):
        """
        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        Such parameters are replaced, recursively, by uninitialized
        parameters of the same shape with the model dtype on the configured
        device.
        """
        for name, param in module.named_parameters(recurse=False):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data,
                                     dtype=self.model_config.dtype,
                                     device=self.device_config.device))
                setattr(module, name, new_param)
        for child in module.children():
            self.init_parameters(child)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped model on one unbatched input.

        Args:
            input_ids: Token ids; ignored on non-first pipeline ranks.
            positions: Position ids.
            intermediate_tensors: Required on non-first pipeline ranks,
                where its "hidden_states" entry is used as the input
                embeddings.
            inputs_embeds: Optional precomputed input embeddings.

        Returns:
            The hidden states, wrapped in `IntermediateTensors` on every
            pipeline rank except the last.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        # Add the leading batch dimension expected by the wrapped model.
        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load `weights` through `AutoWeightsLoader`.

        Uses `self.skip_prefixes` and `self.hf_to_vllm_mapper`.
        NOTE(review): `hf_to_vllm_mapper` is not defined on this base class,
        so subclasses are presumably expected to provide it - confirm.

        Returns:
            The set of names returned by the loader.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
642
643


644
645
646
647
648
649
650
651
652
653
654
655
@support_torch_compile
class TransformersModel(TransformersBase):
    """`TransformersBase` subclass that only adds a checkpoint-name mapper.

    The mapper is read by `TransformersBase.load_weights`.
    """
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Add `model.` prefix for base model checkpoints
            "": "model.",
            # Remove `model.` from places it should not be
            "model.model.": "model.",
            "model.score": "score",
        })


656
@support_torch_compile
class TransformersForCausalLM(TransformersBase):
    """`TransformersBase` with a language-model head and logits processing.

    The head and the logits processor are only created on the last pipeline
    rank; other ranks hold a `PPMissingLayer` in place of the head.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the base model, then the LM head on the last rank.

        Args:
            vllm_config: Forwarded to `TransformersBase.__init__`.
            prefix: Forwarded to the base class and used to prefix the
                name of the LM head.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Tell `TransformersBase.load_weights` to skip
        # `lm_head` if the model has tied word embeddings
        if self.text_config.tie_word_embeddings:
            self.skip_prefixes = ["lm_head."]

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = self.text_config.vocab_size
            self.lm_head = ParallelLMHead(
                self.text_config.vocab_size,
                self.text_config.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if self.text_config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.get_input_embeddings())

            # Falls back to a scale of 1.0 when the config has no
            # `logit_scale` attribute.
            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, self.text_config.vocab_size,
                logit_scale)
        else:
            self.lm_head = PPMissingLayer()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from `hidden_states` via the logits processor.

        NOTE: `logits_processor` only exists on the last pipeline rank, so
        this must only be called there.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

695

696
697
698
699
700
701
702
703
704
705
706
def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    """Concatenate tensors along dim 0, flattening first if shapes differ.

    When all tensors share the same shape after the first dimension they
    are concatenated directly. Otherwise the list is passed through
    `flatten_bn` and the function recurses on the result.
    """
    trailing_shapes = {tensor.shape[1:] for tensor in x}
    if len(trailing_shapes) != 1:
        return flatten_and_concat(flatten_bn(x))
    return torch.concat(x)


707
708
709
710
@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
    dummy_inputs=MultiModalDummyInputsBuilder)
711
712
713
714
715
716
717
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })  # set `positions` to last dim to support Qwen-mrope
718
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
    # Backwards compatibility for previously released models. State dicts
    # back then had different formats and cannot be loaded with the
    # `AutoModel` mapping as is, so old checkpoint prefixes are rewritten to
    # the current module layout here.
    # NOTE(review): entry order may matter if `WeightsMapper` applies
    # prefixes in insertion order - confirm before reordering.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        })

742
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialize through the parent constructor and cache the dtype.

        Args:
            vllm_config: vLLM configuration. Forwarded to the parent
                constructor; its `model_config.dtype` is stored on `self`.
            prefix: Forwarded unchanged to the parent constructor.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Kept on the instance so `get_multimodal_embeddings` can cast
        # incoming pixel values to the model dtype.
        self.dtype = vllm_config.model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the parent forward pass, merging multimodal inputs if needed.

        Args:
            input_ids: Token ids; replaced by `None` once they have been
                converted to embeddings here.
            positions: Passed through to the parent `forward`.
            intermediate_tensors: Passed through to the parent `forward`.
            inputs_embeds: Precomputed input embeddings. When given, the
                multimodal keyword arguments are not consulted.
            **kwargs: Multimodal inputs forwarded to
                `get_multimodal_embeddings` when `inputs_embeds` is None.

        Returns:
            Whatever the parent `forward` returns.
        """
        # NOTE: In v1, inputs_embeds is always generated at model runner from
        # `get_multimodal_embeddings` and `get_input_embeddings`, this
        # condition is only for v0 compatibility.
        if inputs_embeds is None:
            multimodal_embeds = self.get_multimodal_embeddings(**kwargs)
            if multimodal_embeds is not None:
                inputs_embeds = self.get_input_embeddings(
                    input_ids, multimodal_embeds)
                # The embeddings now carry the prompt; drop the raw ids.
                input_ids = None

        return super().forward(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

    def get_multimodal_embeddings(self, **kwargs):
        """Compute image embeddings from the multimodal keyword inputs.

        Recognized keys (each is removed from `kwargs` when it is read):
            image_embeds: Returned unchanged when it is not None.
            pixel_values: Raw image input, a tensor or a list of tensors.
            image_patches: Fallback for `pixel_values`; only consulted when
                `pixel_values` is absent or None.
            num_image_patches: Required whenever pixel input is present. A
                tensor, or a list of tensors which is concatenated.

        Every remaining keyword value is flattened over its first two
        dimensions and forwarded to `self.model.get_image_features`.

        Returns:
            `image_embeds` if it was provided; `None` if no image input was
            provided; otherwise the output of `get_image_features`. When
            that output is a tensor it is split according to
            `num_image_patches` and returned as a list of tensors with all
            leading dimensions flattened.

        Raises:
            KeyError: Pixel input is present but `num_image_patches` is not.
            ValueError: Pixel input is neither a tensor nor a list of
                tensors.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        # Precomputed embeddings take priority over raw pixel input.
        if image_embeds is not None:
            return image_embeds

        if pixel_values is None:
            return None

        num_image_patches = kwargs.pop("num_image_patches")
        if isinstance(pixel_values, torch.Tensor):
            pixel_values = flatten_bn(pixel_values).to(self.dtype)
        elif is_list_of(pixel_values, torch.Tensor):
            pixel_values = flatten_and_concat(pixel_values).to(self.dtype)
        else:
            raise ValueError(
                f"Unsupported pixel_values type {type(pixel_values)}. "
                "Expected `torch.Tensor` or list of `torch.Tensor`.")

        if isinstance(num_image_patches, list):
            num_image_patches = torch.cat(num_image_patches)

        vision_embeddings = self.model.get_image_features(
            pixel_values,
            **{
                k: v.flatten(0, 1)
                for k, v in kwargs.items()
            },
        )

        if isinstance(vision_embeddings, torch.Tensor):
            if vision_embeddings.ndim == 2:
                vision_embeddings = vision_embeddings.unsqueeze(0)

            # Embeddings have to be 2D tensors of length `num_images`
            # but transformers returns concat tensors if each patch
            # is of different size. We split it back to make vLLM happy
            vision_embeddings = torch.split(
                vision_embeddings,
                num_image_patches.flatten().tolist())
            vision_embeddings = [
                embed.flatten(start_dim=0, end_dim=-2)
                for embed in vision_embeddings
            ]

        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings=None,
    ) -> torch.Tensor:
        """Embed `input_ids` and splice in multimodal embeddings.

        Args:
            input_ids: Token ids to embed with the model's input embedding
                layer.
            multimodal_embeddings: Optional sequence of tensors. They are
                concatenated and written, in order, into the positions
                where `input_ids` equals `self.config.image_token_id`.

        Returns:
            The text embeddings, with image-token positions replaced when
            non-empty `multimodal_embeddings` are given.
        """
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        mask = (input_ids == self.config.image_token_id)
        mask = mask.unsqueeze(-1).expand_as(inputs_embeds)

        # `masked_scatter` requires source and destination to share a dtype,
        # so align the multimodal embeddings with the text embeddings.
        flat_embeds = torch.cat(multimodal_embeddings).to(
            dtype=inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(mask, flat_embeds)