transformers.py 37.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
18

19
from collections.abc import Iterable, Mapping
from contextlib import contextmanager
from pathlib import Path
from typing import Literal, Optional, Union

import regex as re
import torch
import transformers
from packaging.version import Version
from torch import nn
from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    ModelConfig,
    ParallelConfig,
    VllmConfig,
)
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalUUIDDict,
    PlaceholderRange,
)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    flatten_bn,
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
)
86
87
88
89

# Module-level logger, named after this module.
logger = init_logger(__name__)


90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build a hint telling the user where to ask for a missing feature.

    Models loaded with remote code are pointed at the model's Hugging Face
    discussions page; all others at the Transformers issue tracker. If
    `model` is an existing local path, the "Please open ..." sentence is
    omitted and only the pointer to vLLM's custom-model docs is returned.

    Args:
        model: Model name on the Hub, or a local path.
        trust_remote_code: Whether the model is loaded with remote code.

    Returns:
        The tip message.
    """
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    tip = f"See {doc_url} for instructions on how to add support yourself."
    if Path(model).exists():
        return tip
    if trust_remote_code:
        url = f"a discussion at https://huggingface.co/{model}/discussions/new"
    else:
        url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
    return f"Please open {url} to request support for this feature. {tip}"


105
def vllm_flash_attention_forward(
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # Transformers kwargs
    scaling: Optional[float] = None,
    # vLLM kwargs
    attention_instances: Optional[dict[int, Attention]] = None,
    **kwargs,
):
    """Attention function that delegates a Transformers layer to vLLM.

    The vLLM `Attention` instance is selected by `module.layer_idx`.

    Args:
        module: The Transformers attention module; its `layer_idx` selects
            the vLLM `Attention` instance to use.
        query, key, value: Projected tensors as passed by Transformers.
            Dims 1 and 2 are swapped and the result is flattened to
            `(num_tokens, -1)` with `num_tokens = query.shape[-2]`.
            This presumably assumes `(1, num_heads, num_tokens, head_dim)`,
            i.e. batch size 1 as produced by `TransformersBase.forward`.
        attention_mask: Ignored by this function.
        scaling: Softmax scale chosen by Transformers. When given, it
            overwrites the scale of the vLLM attention implementation.
        attention_instances: Mapping from layer index to vLLM `Attention`.

    Returns:
        `(attn_output, None)`; attention weights are never returned.

    Raises:
        ValueError: If `attention_instances` was not supplied.
    """
    if attention_instances is None:
        raise ValueError(
            "vllm_flash_attention_forward requires `attention_instances` "
            "to be passed through the model's forward kwargs."
        )
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    num_tokens = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1) for x in (query, key, value))
    return self_attn.forward(query, key, value), None
125
126
127
128
129


# Register `vllm_flash_attention_forward` under the name "vllm" in the
# Transformers attention-function mapping. `TransformersBase.__init__` sets
# `_attn_implementation = "vllm"`, which presumably makes Transformers look
# the function up here by that name.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


130
131
132
133
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Log at debug level that the module at `name` was swapped out.

    The message has the form `"<name>: <old_module> -> <new_module>"`.
    """
    message_format = "%s: %s -> %s"
    logger.debug(message_format, name, old_module, new_module)


134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Decide whether `@support_torch_compile` may compile the model.

    Meant to be passed as the `enable_if` argument of `@support_torch_compile`.
    Compilation is allowed unless the text config requests dynamic rope
    scaling (`rope_scaling["rope_type"] == "dynamic"`).
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    # Dynamic rope scaling is not compatible with torch.compile
    uses_dynamic_rope = rope_scaling.get("rope_type") == "dynamic"
    return not uses_dynamic_rope


151
# Tensor parallel styles understood by `replace_linear_class`; that function
# maps any other style string to `ReplicatedLinear`.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
152
153


154
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Only the shape (`in_features`, `out_features`) and the presence of a bias
    are taken from `linear`. Its weights are not copied.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear, e.g. "colwise".
            Style strings not in `Style` fall back to `ReplicatedLinear`.
        quant_config: Quantization config for the new linear.
        prefix: Qualified name of the layer, forwarded to the new linear.

    Returns:
        The new linear, constructed with `return_bias=False`.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )

193

194
195
196
197
198
199
200
201
202
203
204
205
206
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Replace a Transformers RMSNorm with vLLM's RMSNorm.

    This method assumes:
    - Weight is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon`.
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
    and Transformers doesn't appear to have the same concept.

    Args:
        rms_norm: Transformers norm module to read settings from. Its weight
            values are not copied.
        hidden_size: Hidden size of the new norm.

    Returns:
        A new vLLM `RMSNorm`.
        - Epsilon defaults to 1e-6 when neither attribute exists.
        - When `rms_norm` has a weight, its dtype is reused.
        - Otherwise the norm is created without a weight.
    """
    kwargs = {
        "hidden_size": hidden_size,
        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
        "has_weight": getattr(rms_norm, "with_scale", True),
    }
    weight = getattr(rms_norm, "weight", None)
    if weight is None:
        # No weight, fall back to weightless RMSNorm
        kwargs["has_weight"] = False
    else:
        # If weight is a Parameter, get its data tensor
        kwargs["dtype"] = getattr(weight, "data", weight).dtype
    return RMSNorm(**kwargs)


219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# Adapted from `accelerate`'s `init_on_device` (with buffers excluded)
@contextmanager
def init_on_device_without_buffers(device: Union[torch.device, str]):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    It works by temporarily replacing `nn.Module.register_parameter`. The
    original method is restored on exit, even if the body raises.

    Args:
        device (`torch.device` or `str`):
            Device to initialize all parameters on, e.g. `"meta"`.
    """
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            registered = module._parameters[name]
            param_cls = type(registered)
            # Copy rather than alias `__dict__`, so that adding
            # `requires_grad` does not mutate the caller's parameter object.
            kwargs = dict(registered.__dict__)
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(registered.to(device), **kwargs)

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter


class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for multimodal models run through the Transformers
    backend. Only the "image" modality is declared."""

    def get_supported_mm_limits(self):
        # No explicit per-prompt limit (presumably what `None` means for the
        # base class; TODO confirm)
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        """Return the maximum number of tokens a single image may occupy."""
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        """Return the number of image tokens the HF processor reports for one
        image of the size given by `get_max_image_size`.

        `mm_processor_kwargs` from the multimodal config are forwarded to the
        HF processor.
        """
        width, height = self.get_max_image_size()
        processor = self.get_hf_processor()
        multimodal_config = self.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
        # `image_sizes` entries are given as [height, width]
        mm_tokens = processor._get_num_multimodal_tokens(
            image_sizes=([height, width],), **mm_processor_kwargs
        )
        return mm_tokens["num_image_tokens"][0]

    def get_max_image_size(self):
        """Return `(width, height)` of the largest image considered."""
        return 10_000, 10_000  # hardcode for arbitrary very large size


293
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy text and image inputs for the Transformers multimodal
    backend. Dummy images use the size from `get_max_image_size()`."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image.

        Gemma3 processors use their `boi_token`. Other processors use
        `image_token`, or an empty string if they do not define one.
        """
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        if "gemma3" in processor.__class__.__name__.lower():
            image_token = processor.boi_token
        else:
            image_token = getattr(processor, "image_token", "")
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` dummy images of the maximum image size.

        Any `"image"` entry of `mm_options` is forwarded as overrides.
        """
        num_images = mm_counts.get("image", 0)
        target_width, target_height = self.info.get_max_image_size()
        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }


class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Multimodal processor for the Transformers backend.

    The HF processor is applied to text and images together. Placeholder
    ranges are derived from the multimodal token type ids it returns, rather
    than from vLLM prompt updates.
    """

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.

        This processor computes placeholders itself in `apply`, so it returns
        no prompt updates.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs,
        num_image_patches: Optional[torch.Tensor] = None,
    ):
        """Describe how each processed field is split across images.

        - Every key left in `hf_inputs`, plus `image_embeds`, is a flat field
          split per image by `num_image_patches`.
        - `num_image_patches` itself is batched per image.

        Note: removes `attention_mask` from `hf_inputs` in place.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        mm_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
            for key in hf_inputs
        }
        mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
            "image", num_image_patches
        )
        mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
        return mm_fields

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, torch.Tensor]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            A tuple of three items:
            - the prompt token ids;
            - the processed data, with `input_ids` and the token type ids
              removed;
            - the multimodal token type ids. `apply` treats entries equal to
              1 as multimodal tokens.

        Raises:
            KeyError: If the HF processor returned neither `mm_token_type_ids`
                nor `token_type_ids`.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
        processor_data["return_mm_token_type_ids"] = True

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        processed_data.update(passthrough_data)

        (prompt_ids,) = processed_data.pop("input_ids").tolist()
        if "mm_token_type_ids" in processed_data:
            mm_token_type_ids = processed_data.pop("mm_token_type_ids")
        else:
            # for gemma3 only
            mm_token_type_ids = processed_data.pop("token_type_ids")

        return prompt_ids, processed_data, mm_token_type_ids

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors. Token-id prompts are first
        decoded back to text, because the HF processor needs a string.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # the prompt is the tokenized ids which is not supported
            # by the hf_processor, which is why we would need to decode the ids
            # into string
            prompt = hf_processor.decode(prompt)

        prompt_ids, processed_data, mm_token_type_ids = (
            self._apply_hf_processor_text_mm(
                prompt_text=prompt,
                mm_items=mm_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
            )
        )

        # HF processor will return `mm_token_type_ids` from which
        # we can infer mm_placeholders. Until then hardcode to make code run
        # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
        mm_positions = torch.where(mm_token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)
        multimodal_config = self.info.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}

        # `image_sizes` entries are (height, width)
        image_sizes = [
            (size.height, size.width)
            for size in map(images.get_image_size, range(len(images)))
        ]
        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs
        )

        mm_placeholders = {}
        split_sizes = mm_tokens_per_modality["num_image_tokens"]
        if split_sizes:
            # One chunk of positions / token ids per image
            chunked_mm_positions = torch.split(mm_positions, split_sizes)
            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
            chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
            mm_placeholders = {
                "image": [
                    PlaceholderRange(
                        offset=positions[0].item(),
                        length=positions.shape[0],
                        # Only positions holding the image token get embeddings
                        is_embed=(image_chunk == hf_processor.image_token_id),
                    )
                    for positions, image_chunk in zip(
                        chunked_mm_positions, chunked_mm_tokens
                    )
                ]
            }

        num_image_patches = (
            torch.tensor(mm_tokens_per_modality["num_image_patches"])
            if "num_image_patches" in mm_tokens_per_modality
            else None
        )
        processed_data["num_image_patches"] = num_image_patches
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(
                processed_data, hf_processor_mm_kwargs, num_image_patches
            ),
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


492
493
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Common base for vLLM models backed by a Transformers implementation.

    Steps performed on the Transformers model:
    1. Build it on the "meta" device, with the "vllm" attention
       implementation.
    2. Prune the layers not owned by this pipeline parallel rank.
    3. Swap its `nn.Linear` layers for vLLM's tensor parallel linears.
    4. Swap its input embedding for `VocabParallelEmbedding`.
    5. Materialize any remaining meta parameters.
    """

    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # Weights to skip in `self.load_weights`:
        # skip loading weights whose qualname starts with these prefixes,
        self.skip_prefixes: list[str] = []
        # skip loading weights whose qualname contains these substrings,
        self.skip_substrs: list[str] = []
        # ignore unexpected weights whose qualname starts with these prefixes,
        self.ignore_unexpected_prefixes: list[str] = []
        # ignore unexpected weights whose qualname ends with these suffixes.
        self.ignore_unexpected_suffixes: list[str] = []

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError(
                    "Transformers backend does not support MXFP4 quantization yet."
                )
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        # TODO: @raushan, use the public `model.set_attn_implementation()`
        # method once its checks are fixed in Transformers.
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            if embedding_dim is None:
                raise ValueError(
                    "Could not determine the embedding size: the text config "
                    "defines neither `embedding_size` nor `hidden_size`."
                )
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                )
            )

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.text_config.hidden_size
        )

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Modules outside this rank's share are replaced with `PPMissingLayer`:
        - Entries of `_pp_plan` before the layer `ModuleList` are kept on the
          first rank, and also on the last rank when word embeddings are tied.
        - Only layers `[start_layer, end_layer)` of the `ModuleList` are kept.
        - Entries after the `ModuleList` are kept on the last rank only.

        Raises:
            ValueError: If the model has no pipeline parallel plan, or if its
                plan does not reference exactly one `ModuleList`.
        """
        if self.pp_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!"
            )
        if module_list_idx is None:
            raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        keep_pre_layers = self.pp_group.is_first_rank or (
            self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
        )
        if not keep_pre_layers:
            for name in pp_plan[:module_list_idx]:
                setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
        )
        layers = getattr(self.model, pp_plan[module_list_idx])
        for i in range(len(layers)):
            if not start_layer <= i < end_layer:
                layers[i] = PPMissingLayer()

        # Layers after module list: modules that should be on last rank
        if not self.pp_group.is_last_rank:
            for name in pp_plan[module_list_idx + 1 :]:
                setattr(self.model, name, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently, this replaces:

        - `nn.Linear` with vLLM's tensor parallel linear classes. The style
          comes from the first `tp_plan` pattern matching the layer's
          qualified name, or "replicate" if none matches.

        RMSNorm replacement is not enabled yet (see the TODO below).

        Raises:
            ValueError: If tensor parallelism is requested but the model has
                no tensor parallel plan.
        """
        tp_plan = self.model.tp_plan or {}

        if not tp_plan and self.tp_size > 1:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}"
            )

        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    pattern = next(
                        (p for p in tp_plan if re.match(p, qual_name)), None
                    )
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(
                        child_module, style, self.quant_config, prefix=qual_name
                    )
                # TODO(hmellor): Enable RMSNorm replacement (via
                # `replace_rms_norm_class`) once we have a way to choose
                # RMSNorm vs GemmaRMSNorm
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(
        self, attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created per layer owned by this pipeline rank, keyed
        by layer index. Layers marked "sliding_attention" in the text config's
        `layer_types` get the text config's `sliding_window`.
        """
        num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(
            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
        )

        # Read from the text config: for composite (multimodal) configs these
        # attributes live on the text sub-config, not the top-level config.
        layer_types = getattr(self.text_config, "layer_types", None)
        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if layer_types is not None and layer_types[i] == "sliding_attention":
                per_layer_sliding_window = self.text_config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type,
            )
        return attention_instances

    def init_parameters(self, module: nn.Module, dtype: Optional[torch.dtype] = None):
        """
        Materialize every parameter of `module` (recursively) that is still
        on the `meta` device.

        Such a parameter belongs to an original module, i.e. one that was not
        replaced after being created by:

        ```python
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        Replacements are allocated uninitialized (`torch.empty_like`) on
        `self.device_config.device`. They use `dtype` if given, otherwise the
        model dtype. Their values are presumably filled by weight loading.
        """

        def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        )
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped Transformers model on a flat batch of tokens.

        - A batch dimension of size 1 is added to the inputs and removed from
          the model output.
        - On non-first pipeline ranks the input is
          `intermediate_tensors["hidden_states"]`.
        - On non-last ranks the output is wrapped in `IntermediateTensors`.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        # Add the batch dimension Transformers expects
        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        # With M-RoPE the batch dim is inserted second rather than first
        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False,
        )[0][0, ...]  # we remove batch dimension for now

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights with `AutoWeightsLoader`.

        Honours the skip/ignore lists configured in `__init__` and returns
        whatever `AutoWeightsLoader.load_weights` returns.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        # Subclasses may define `hf_to_vllm_mapper` to rename checkpoint keys.
        # This class does not define one, so fall back to no mapping instead
        # of raising AttributeError.
        mapper = getattr(self, "hf_to_vllm_mapper", None)
        return loader.load_weights(weights, mapper=mapper)

    def check_version(self, min_version: str, feature: str):
        """Raise `ImportError` if the installed Transformers is older than
        `min_version`, naming `feature` in the message."""
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers backend requires transformers>={required} "
                f"for {feature}, but got {installed}"
            )
802

803

804
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    """Causal LM on top of the `transformers` backend base model.

    On the last pipeline-parallel rank this adds an `lm_head`
    (`ParallelLMHead`) and a `LogitsProcessor`. On every other rank
    `lm_head` is a `PPMissingLayer` placeholder and no logits processor
    is created.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Tell `TransformersBase.load_weights` to skip
        # `lm_head` if the model has tied word embeddings
        if self.text_config.tie_word_embeddings:
            self.skip_prefixes.append("lm_head.")

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = self.text_config.vocab_size
            self.lm_head = ParallelLMHead(
                self.text_config.vocab_size,
                self.text_config.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if self.text_config.tie_word_embeddings:
                # Reuse the input embedding weights for the output head.
                self.lm_head = self.lm_head.tie_weights(
                    self.model.get_input_embeddings()
                )

            # `logit_scale` is optional in the config; 1.0 when absent.
            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
            )
        else:
            # Non-last PP ranks hold no output head.
            self.lm_head = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed `input_ids` with the wrapped model's input embedding layer."""
        return self.model.get_input_embeddings()(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project `hidden_states` to vocabulary logits through `lm_head`.

        Uses `self.logits_processor`, which `__init__` only creates on the
        last pipeline-parallel rank.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits


def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    """Flatten a list of tensors until they can be concatenated, then concat.

    Tensors can be concatenated along dim 0 only when all of them agree on
    every dimension after dim 0. While they disagree, the list is passed
    through `flatten_bn` and the check is repeated. `flatten_bn` is assumed
    to split each tensor along dim 0, lowering its rank by one; confirm
    against its definition.

    Args:
        x: List of tensors to combine.

    Returns:
        A single tensor concatenated along dim 0.

    Raises:
        ValueError: If `x` (or an intermediate flattened list) is empty.
            Without this guard the shape check could never succeed and the
            function would recurse until `RecursionError`.
    """
    if not x:
        raise ValueError("Cannot flatten and concatenate an empty list of tensors")

    def _can_concat(tensors: list[torch.Tensor]) -> bool:
        # All trailing shapes (everything after dim 0) must be identical.
        return len({t.shape[1:] for t in tensors}) == 1

    if _can_concat(x):
        return torch.concat(x)
    return flatten_and_concat(flatten_bn(x))


@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
859
860
    dummy_inputs=MultiModalDummyInputsBuilder,
)
861
@support_torch_compile(
862
    # set `positions` to last dim to support Qwen-mrope
863
864
865
866
867
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
868
    },
869
870
    enable_if=can_enable_torch_compile,
)
871
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
872
    merge_by_field_config = True

    # Backwards compatibility for prev released models. State dicts back then
    # had different formats and cannot be loaded with `AutoModel` mapping as is
    # NOTE(review): keep the entry order unchanged; the mapper may apply
    # prefixes in insertion order.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each of submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the causal LM via the parent class and record the model dtype.

        Args:
            vllm_config: Global vLLM configuration.
            prefix: Module-name prefix forwarded to the parent constructor.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Store the configured model dtype on the instance.
        self.dtype = vllm_config.model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the forward pass by delegating to the parent class.

        Extra keyword arguments are accepted for interface compatibility
        but are not forwarded to the parent `forward`.

        Returns:
            Whatever the parent `forward` returns: hidden states, or
            `IntermediateTensors` for non-final pipeline stages.
        """
        model_output = super().forward(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def get_language_model(self) -> torch.nn.Module:
        """Return `self.model`, which serves as this wrapper's language model."""
        return self.model

    def get_multimodal_embeddings(self, **kwargs):
        """Compute image embeddings from multimodal keyword arguments.

        Precedence: precomputed `image_embeds` are returned unchanged.
        Otherwise `pixel_values` (or `image_patches`, which some models use
        instead) are encoded with `self.model.get_image_features`.

        Args:
            **kwargs: Multimodal inputs. `num_image_patches` is required
                whenever pixel inputs are present. Keys not popped here are
                forwarded to `get_image_features`.

        Returns:
            - `image_embeds` if it was provided.
            - `None` if there are no image inputs.
            - Otherwise the vision embeddings. When `get_image_features`
              returns a single tensor, it is split along dim 0 by
              `num_image_patches` and each chunk is flattened to 2D.

        Raises:
            KeyError: If pixel inputs are given without `num_image_patches`.
        """
        pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
        image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None)
        # Model might use `image_patches` instead of `pixel_values`
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)

        if image_embeds is not None:
            return image_embeds

        if pixel_values is None:
            return None

        # From here on `pixel_values` is known to be set, so no further
        # None-check is needed.
        num_image_patches = kwargs.pop("num_image_patches")
        vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

        if isinstance(vision_embeddings, torch.Tensor):
            if isinstance(num_image_patches, list):
                num_image_patches = torch.cat(num_image_patches)

            if vision_embeddings.ndim == 2:
                vision_embeddings = vision_embeddings.unsqueeze(0)

            # Embeddings have to be 2D tensors of length `num_images`
            # but transformers returns concat tensors if each patch
            # is of different size. We split it back to make vLLM happy
            vision_embeddings = torch.split(
                vision_embeddings, num_image_patches.flatten().tolist()
            )
            vision_embeddings = [
                embed.flatten(start_dim=0, end_dim=-2)
                for embed in vision_embeddings
            ]

        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """
        Apply token embeddings to `input_ids`.

        If `multimodal_embeddings` is passed, scatter them into
        `input_ids` according to the mask `is_multimodal`.

        In case the multi-modal token IDs exceed the vocabulary size of
        the language model, you can set `handle_oov_mm_token=True`
        to avoid calling the language model's `get_input_embeddings` method
        on those tokens. The flag is forwarded, together with
        `is_multimodal`, to `_get_text_embeddings`.

        Raises:
            ValueError: If non-empty `multimodal_embeddings` are given
                without `is_multimodal`.
        """
        from .utils import _merge_multimodal_embeddings

        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.model.get_input_embeddings(),
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # Text-only batch: nothing to merge.
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229."
            )

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )