transformers.py 36.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
18

19
from collections.abc import Iterable, Mapping
20
from contextlib import contextmanager
21
from pathlib import Path
22
from typing import Literal
23

24
import regex as re
25
import torch
26
27
import transformers
from packaging.version import Version
28
from torch import nn
29
from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel
30
31
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

32
from vllm.attention import Attention, AttentionType
33
from vllm.compilation.decorators import support_torch_compile
34
35
36
37
38
39
40
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    ModelConfig,
    ParallelConfig,
    VllmConfig,
)
41
from vllm.config.multimodal import BaseDummyOptions
42
from vllm.config.utils import getattr_iter
43
from vllm.distributed import get_pp_group, get_tp_group
44
from vllm.distributed.utils import get_pp_indices
45
from vllm.logger import init_logger
46
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
47
48
49
50
51
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
52
from vllm.model_executor.layers.logits_processor import LogitsProcessor
53
from vllm.model_executor.layers.quantization import QuantizationConfig
54
from vllm.model_executor.layers.vocab_parallel_embedding import (
55
56
57
    ParallelLMHead,
    VocabParallelEmbedding,
)
58
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
59
60
61
62
63
64
65
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalUUIDDict,
    PlaceholderRange,
)
66
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
67
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
68
from vllm.multimodal.profiling import BaseDummyInputsBuilder
69
70
from vllm.sequence import IntermediateTensors

71
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant
72
73
74
75
76
77
78
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
)
79
80
81
82

# Module-level logger; used below for replacement debug logs and warnings.
logger = init_logger(__name__)


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build a hint telling the user where to request or add a missing feature.

    Args:
        model: Model name or local path, as passed to vLLM.
        trust_remote_code: Whether the model uses remote code. Remote-code
            models are pointed at the model's Hugging Face discussion page,
            all others at the Transformers GitHub issue tracker.

    Returns:
        A sentence pointing at vLLM's custom-model docs. Unless `model` is an
        existing local path, it is preceded by a request-support sentence.
    """
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    self_help = f"See {doc_url} for instructions on how to add support yourself."

    # A local checkpoint has no public page to open a discussion on.
    if Path(model).exists():
        return self_help

    if trust_remote_code:
        where = f"a discussion at https://huggingface.co/{model}/discussions/new"
    else:
        where = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
    return f"Please open {where} to request support for this feature. {self_help}"


98
def vllm_flash_attention_forward(
99
100
101
102
103
104
105
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # Transformers kwargs
106
    scaling: float | None = None,
107
    # vLLM kwargs
108
    attention_instances: dict[Attention] | None = None,
109
110
    **kwargs,
):
111
112
113
114
115
116
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    hidden = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
117
    return self_attn.forward(query, key, value), None
118
119
120
121
122


# Register the vLLM attention function with Transformers under the key "vllm".
# `TransformersBase` sets `_attn_implementation = "vllm"` so its models use it.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


123
124
125
126
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug record that the module at `name` was swapped out.

    Args:
        name: Qualified name of the replaced submodule.
        old_module: The module that was removed.
        new_module: The module that took its place.
    """
    template = "%s: %s -> %s"
    logger.debug(template, name, old_module, new_module)


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Decide whether `@support_torch_compile` should compile the model.

    Intended as the decorator's `enable_if` callable. Compilation stays on
    unless the text config's `rope_scaling` requests `rope_type == "dynamic"`.
    """
    hf_text_config = vllm_config.model_config.hf_config.get_text_config()
    rope_cfg: dict = getattr(hf_text_config, "rope_scaling", None) or {}
    # Dynamic rope scaling is not compatible with torch.compile
    return rope_cfg.get("rope_type") != "dynamic"


144
# Tensor parallel styles mapped to vLLM linear classes by `replace_linear_class`.
# Values come from the model's `tp_plan`; "replicate" is the default.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
145
146


147
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: QuantizationConfig | None = None,
    *,
    prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear: `nn.Linear` to be replaced. Only its input/output sizes and
            whether it has a bias are used; its weights are not copied.
        style: Tensor parallel style of the new linear, e.g. "colwise".
            Unrecognised style strings fall back to `ReplicatedLinear`.
        quant_config: Quantization config for the new linear.
        prefix: Qualified name of the layer, forwarded to the vLLM class.

    Returns:
        The new linear, constructed with `return_bias=False`.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    styles = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }
    if style not in styles:
        # Keep the best-effort fallback, but make it visible when debugging.
        logger.debug(
            "Unsupported parallel style %r for %s, falling back to replicate",
            style,
            prefix,
        )
    vllm_linear_cls, vllm_linear_kwargs = styles.get(style, styles["replicate"])

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )

186

187
188
189
190
191
192
193
194
195
196
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Build the vLLM norm layer that stands in for a Transformers RMSNorm.

    Assumptions about `rms_norm`:
    - Its weight, if any, is stored as `weight`.
    - Its epsilon is stored as `eps` or `variance_epsilon` (default 1e-6).
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
    and Transformers doesn't appear to have the same concept.

    A `GemmaRMSNorm` is returned when a freshly built instance of the same
    class initialises its weight to all zeros (a zero-centred weight).
    Otherwise an `RMSNorm` is returned, weightless if `rms_norm` has no weight.
    """
    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
    weight = getattr(rms_norm, "weight", None)
    kwargs = {
        # Prefer the actual weight width over the caller's hint
        "hidden_size": hidden_size if weight is None else weight.size(0),
        "eps": eps,
    }

    # `rms_norm` is on the meta device, so build a throwaway 1-wide CPU
    # instance of the same class to see how its weight is initialised.
    try:
        with torch.device("cpu"):
            probe_weight = getattr(rms_norm.__class__(1), "weight", None)
    except Exception:
        logger.warning(
            "Failed to determine if RMSNorm weight is centered on zero or one. "
            "Defaulting to one."
        )
        probe_weight = None

    if probe_weight is not None and torch.all(probe_weight == 0):
        return GemmaRMSNorm(**kwargs)

    # Regular RMSNorm; without a weight on the original it must be weightless
    if weight is None:
        kwargs["has_weight"] = False
    else:
        kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
        kwargs["dtype"] = weight.dtype
    return RMSNorm(**kwargs)


226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
    """

    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
248
249
                module._parameters[name].to(device), **kwargs
            )
250
251
252
253
254
255
256
257
258
259
260
261
262
263

    tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        for torch_function_name in tensor_constructors_to_patch:
            setattr(
264
265
266
267
                torch,
                torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)),
            )
268
269
270
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
271
272
273
274
        for (
            torch_function_name,
            old_torch_function,
        ) in tensor_constructors_to_patch.items():
275
276
277
278
279
280
281
282
283
284
285
286
287
            setattr(torch, torch_function_name, old_torch_function)


class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for multimodal models run through the Transformers backend.

    Only the "image" modality is declared. Token counts come from the HF
    processor's `_get_num_multimodal_tokens`.
    """

    def get_supported_mm_limits(self):
        # `None`: no fixed per-prompt image limit is declared here
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        # The worst case is independent of `seq_len` and `mm_counts`
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        """Return the token count of a single image at the maximum size."""
        max_w, max_h = self.get_max_image_size()
        hf_processor = self.get_hf_processor()
        mm_config = self.ctx.model_config.multimodal_config
        extra_kwargs = mm_config.mm_processor_kwargs or {}
        token_counts = hf_processor._get_num_multimodal_tokens(
            image_sizes=([max_h, max_w],), **extra_kwargs
        )
        return token_counts["num_image_tokens"][0]

    def get_max_image_size(self):
        # Hardcoded, arbitrarily large (width, height)
        return 10_000, 10_000


300
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy prompts and images for profiling the Transformers backend."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        hf_processor = self.info.get_hf_processor()
        # Gemma3 processors use their begin-of-image token as the placeholder
        if "gemma3" in hf_processor.__class__.__name__.lower():
            placeholder = hf_processor.boi_token
        else:
            placeholder = getattr(hf_processor, "image_token", "")
        return placeholder * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` dummy images at the maximum image size.

        `mm_options["image"]`, when present, is passed through as overrides.
        """
        width, height = self.info.get_max_image_size()
        overrides = None
        if mm_options:
            overrides = mm_options.get("image")
        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": dummy_images}


class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Multimodal processor for models run through the Transformers backend.

    Always runs the HF processor on the full prompt plus all images. Token
    type ids returned by the processor are used to locate image placeholders.
    """

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality and the
        HF-processed data, output the prompt updates to perform.

        The base class uses this to update token inputs that bypass the HF
        processor, and to update HF processor output when the processor does
        not apply prompt updates to text inputs itself. It also determines
        the token positions used to build a
        :class:`~vllm.multimodal.inputs.PlaceholderRange` per item.

        This backend locates placeholders itself in `apply`, so no updates
        are returned.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each HF output field is split into per-image items.

        Note: drops `attention_mask` from `hf_inputs` in place.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        patches_per_image = hf_inputs.get("num_image_patches")

        # Every other field (plus `image_embeds`) is flattened over patches
        # and split per image by `num_image_patches`.
        flat_keys = [*hf_inputs, "image_embeds"]
        fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image", patches_per_image)
            for key in flat_keys
        }

        # Keep these as batched, as they always have batch size as first dim
        # (video grids are registered under the "image" modality as well).
        for key in ("image_grid_thw", "video_grid_thw", "num_image_patches"):
            fields[key] = MultiModalFieldConfig.batched("image")
        return fields

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Same as the base class, but always adds `return_mm_token_type_ids`
        to the processor data so placeholders can be located in `apply`.
        """
        hf_data, passthrough = super()._get_hf_mm_data(mm_items)
        hf_data["return_mm_token_type_ids"] = True
        return hf_data, passthrough

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.
        """
        tokenization_kwargs = {} if tokenization_kwargs is None else tokenization_kwargs

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        # The HF processor cannot take token ids, so decode them back to text
        if not isinstance(prompt, str):
            prompt = hf_processor.decode(prompt)

        # Bypass cached processor and always apply to the full set of mm inputs
        # NOTE: we can't just set caching=False because base class method
        # transforms outputs to `MultiModalKwargs` which is not going to
        # work for Transformers. We have a lot of logic tied to
        # `mm_tokens_per_modality` below
        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        # For gemma3 we check `token_type_ids` as the key
        type_ids_key = "token_type_ids"
        if "mm_token_type_ids" in processed_data:
            type_ids_key = "mm_token_type_ids"
        mm_token_type_ids = processed_data.pop(type_ids_key)

        # Multimodal tokens have type id 1. Splitting their column indices per
        # input item yields vLLM-style placeholders.
        mm_positions = torch.where(mm_token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)
        mm_config = self.info.ctx.model_config.multimodal_config
        # NOTE(review): token counts use only the model-level
        # mm_processor_kwargs, not the per-request hf_processor_mm_kwargs;
        # confirm the two cannot disagree.
        extra_kwargs = mm_config.mm_processor_kwargs or {}

        sizes = [images.get_image_size(idx) for idx in range(len(images))]
        image_sizes = [(size.height, size.width) for size in sizes]
        tokens_info = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **extra_kwargs
        )

        mm_placeholders = {}
        tokens_per_image = tokens_info["num_image_tokens"]
        if tokens_per_image:
            prompt_mm_ids = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
            mm_placeholders["image"] = [
                PlaceholderRange(
                    offset=pos_chunk[0].item(),
                    length=pos_chunk.shape[0],
                    is_embed=(id_chunk == hf_processor.image_token_id).bool(),
                )
                for pos_chunk, id_chunk in zip(
                    torch.split(mm_positions, tokens_per_image),
                    torch.split(prompt_mm_ids, tokens_per_image),
                )
            ]

        processed_data["num_image_patches"] = torch.tensor(
            tokens_info["num_image_patches"]
        )
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


487
488
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Shared base for vLLM models backed by a Transformers implementation.

    The Transformers model is instantiated on the meta device and trimmed for
    pipeline parallelism. Its linear, RMSNorm and input-embedding layers are
    swapped for vLLM equivalents. Attention runs through the "vllm" attention
    implementation, backed by one vLLM `Attention` per local layer.
    """

    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: QuantizationConfig | None = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Weights to skip in `self.load_weights`:
        # skip loading weights whose qualname starts with these prefixes
        self.skip_prefixes: list[str] = []
        # skip loading weights whose qualname contains these substrings
        self.skip_substrs: list[str] = []
        # ignore unexpected weights whose qualname starts with these prefixes
        self.ignore_unexpected_prefixes: list[str] = []
        # ignore unexpected weights whose qualname ends with these suffixes
        self.ignore_unexpected_suffixes: list[str] = []

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError(
                    "Transformers backend does not support MXFP4 quantization yet."
                )
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings. Default to "no scale" so `embed_scale` exists even
        # on ranks whose embedding layer was replaced with `PPMissingLayer`.
        self.embed_scale = None
        input_embeddings = self.model.get_input_embeddings()
        if not isinstance(input_embeddings, PPMissingLayer):
            # Some models use embedding scales
            self.embed_scale = getattr(input_embeddings, "embed_scale", None)
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                )
            )

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.text_config.hidden_size
        )

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Entries of the model's `_pp_plan` that this rank does not own are
        replaced with `PPMissingLayer`. Exactly one entry must be an
        `nn.ModuleList`; its layers are split across ranks with
        `get_pp_indices`. Entries before it stay on the first rank (and on the
        last rank when word embeddings are tied). Entries after it stay on the
        last rank.

        Raises:
            ValueError: If the model has no pipeline parallel plan, or the
                plan does not contain exactly one `nn.ModuleList`.
        """
        if self.pp_group.world_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!"
            )
        if module_list_idx is None:
            raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        keep_leading = self.pp_group.is_first_rank or (
            self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
        )
        if not keep_leading:
            for name in pp_plan[:module_list_idx]:
                setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )
        layers = getattr(self.model, pp_plan[module_list_idx])
        for i in range(len(layers)):
            if not start_layer <= i < end_layer:
                layers[i] = PPMissingLayer()

        # Layers after module list: modules that should be on last rank
        if not self.pp_group.is_last_rank:
            for name in pp_plan[module_list_idx + 1 :]:
                setattr(self.model, name, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently, this replaces:

        - `nn.Linear` with vLLM's tensor parallel linear classes, using the
          style of the first matching `tp_plan` pattern ("replicate" if none)
        - `*RMSNorm` with vLLM's `RMSNorm`

        Raises:
            ValueError: If tensor parallelism is requested but the model has
                no `tp_plan`.
        """
        tp_plan = self.model.tp_plan or {}

        if not tp_plan and self.tp_group.world_size > 1:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}"
            )

        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(
                        child_module, style, self.quant_config, prefix=qual_name
                    )
                elif child_module.__class__.__name__.endswith("RMSNorm"):
                    new_module = replace_rms_norm_class(
                        child_module, self.text_config.hidden_size
                    )
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(
        self, attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created per layer owned by this pipeline parallel
        rank, keyed by layer index. Layers marked "sliding_attention" in
        `layer_types` get that config's `sliding_window`.
        """
        num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None)
        start, end = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )

        # Interleaved sliding window attention is described by `layer_types`.
        # Read it from the text config (like `num_hidden_layers` above),
        # falling back to the top-level config; tolerate it being None.
        sliding_config = (
            self.text_config
            if getattr(self.text_config, "layer_types", None) is not None
            else self.config
        )
        layer_types = getattr(sliding_config, "layer_types", None)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if layer_types is not None and layer_types[i] == "sliding_attention":
                per_layer_sliding_window = sliding_config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                logits_soft_cap=logits_soft_cap,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type,
            )
        return attention_instances

    def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None):
        """
        Materialise every parameter of `module` (recursively) still on `meta`.

        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        Each such parameter is replaced by an uninitialised tensor of the same
        shape on `self.device_config.device`, with dtype `dtype` (or the model
        dtype if `dtype` is None). Presumably its values are filled in later
        by weight loading.
        """

        def _init_parameters(module: nn.Module, dtype: torch.dtype | None):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        )
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped Transformers model on this rank's layers.

        Non-first pipeline ranks take their input from
        `intermediate_tensors["hidden_states"]`. A batch dimension of size one
        is added for Transformers and stripped from the output. Non-last ranks
        return `IntermediateTensors`.
        """
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        # Transformers expects a leading batch dimension
        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        # NOTE(review): assumes M-RoPE positions carry a leading sections
        # axis, so the batch dim is inserted after it — confirm.
        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False,
            **kwargs,
        )[0][0, ...]  # we remove batch dimension for now

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights, honouring the skip/ignore lists above.

        Returns the set returned by `AutoWeightsLoader.load_weights`.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def check_version(self, min_version: str, feature: str):
        """Raise `ImportError` if installed Transformers is older than `min_version`.

        Args:
            min_version: Minimum required Transformers version string.
            feature: Name of the feature, used in the error message.
        """
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers backend requires transformers>={required} "
                f"for {feature}, but got {installed}"
            )
803

804

805
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    """Causal LM on top of `TransformersBase`.

    On the last pipeline-parallel rank this adds a `ParallelLMHead` (tied to the
    input embeddings when the text config asks for it) and a `LogitsProcessor`.
    Every other rank gets a `PPMissingLayer` placeholder as its `lm_head`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        text_config = self.text_config
        tied = text_config.tie_word_embeddings

        # Tied embeddings: make `load_weights` skip any `lm_head.*` entries
        # found in the checkpoint.
        if tied:
            self.skip_prefixes.append("lm_head.")

        if not self.pp_group.is_last_rank:
            self.lm_head = PPMissingLayer()
            return

        vocab_size = text_config.vocab_size
        self.unpadded_vocab_size = vocab_size
        head = ParallelLMHead(
            vocab_size,
            text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if tied:
            # `tie_weights` returns the module to use as the LM head.
            head = head.tie_weights(self.model.get_input_embeddings())
        self.lm_head = head

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            vocab_size,
            getattr(text_config, "logit_scale", 1.0),
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids``, scaling in place by ``self.embed_scale`` if set."""
        embed_tokens = self.model.get_input_embeddings()
        embeddings = embed_tokens(input_ids)
        scale = self.embed_scale
        if scale is not None:
            embeddings *= scale
        return embeddings

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` through ``lm_head`` via the logits processor.

        Only meaningful on the last pipeline rank; ``logits_processor`` is not
        created on other ranks.
        """
        return self.logits_processor(self.lm_head, hidden_states)

848
849
850
851

@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
852
853
    dummy_inputs=MultiModalDummyInputsBuilder,
)
854
@support_torch_compile(
855
    # set `positions` to last dim to support Qwen-mrope
856
857
858
859
860
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
861
    },
862
863
    enable_if=can_enable_torch_compile,
)
864
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
865
    supports_multimodal_raw_input_only = True
866
    merge_by_field_config = True
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
    # Backwards compatibility for previously released models: their state dicts
    # use a different layout and cannot be loaded with the `AutoModel` mapping
    # as-is. The prefixes are kept as ordered (old, new) pairs because the
    # mapper walks them in insertion order.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix=dict(
            [
                # Multimodal sub-modules that used to sit at the top level.
                ("language_model.model", "model.language_model"),
                ("text_model.model", "model.text_model"),
                ("vision_tower", "model.vision_tower"),
                ("vqmodel", "model.vqmodel"),
                ("visual", "model.visual"),
                ("vision_model", "model.vision_model"),
                ("vision_embed_tokens", "model.vision_embed_tokens"),
                ("image_newline", "model.image_newline"),
                ("multi_modal_projector", "model.multi_modal_projector"),
                # LM heads that used to live under the text sub-model.
                ("text_model.lm_head", "lm_head"),
                ("language_model.lm_head", "lm_head"),
                # Qwen models called the language model plain "model", so each
                # of its submodules is mapped explicitly; a blanket "model"
                # rule would clash with newer "model.language_model" layouts.
                ("model.embed_tokens", "model.language_model.embed_tokens"),
                ("model.layers", "model.language_model.layers"),
                ("model.norm", "model.language_model.norm"),
            ]
        )
    )
890

891
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the causal-LM part, then record the configured model dtype."""
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        model_config = vllm_config.model_config
        self.dtype = model_config.dtype

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to the parent ``forward``, passing only ``token_type_ids``.

        Every extra keyword argument other than ``token_type_ids`` is dropped
        before the call.
        """
        # Gemma3 and PaliGemma need `token_type_ids`; other models never
        # receive it, so nothing else is forwarded.
        extra_kwargs = {}
        if "token_type_ids" in kwargs:
            extra_kwargs["token_type_ids"] = kwargs["token_type_ids"]
        return super().forward(
            input_ids, positions, intermediate_tensors, inputs_embeds, **extra_kwargs
        )

912
    def get_language_model(self) -> torch.nn.Module:
        """Return a `TransformersForCausalLM`-shaped view of the text backbone.

        There is no dedicated vLLM language-model class behind
        `TransformersForMultimodalLM`. Instead, a throwaway subclass of
        `TransformersForCausalLM` is built. It shares this instance's attribute
        dict and exposes the HF text sub-model (`language_model` or
        `text_model`, whichever exists first) as its `model` attribute.
        """
        # May be `None` if `self.model` has neither attribute.
        backbone = getattr_iter(self.model, ("language_model", "text_model"), None)

        class LanguageModelWrapper(TransformersForCausalLM):
            # Class attribute, not a registered submodule.
            # NOTE(review): `wrapper.model` only resolves to this if "model" is
            # not a plain key of the copied `__dict__` (true when it was
            # registered as an nn.Module submodule) — confirm.
            model = backbone

            def __init__(self, multimodal_model):
                # Deliberately skip `TransformersForCausalLM.__init__`. The
                # update copies references, so container attributes (e.g.
                # `_modules`) are shared with the multimodal model, not cloned.
                self.__dict__.update(multimodal_model.__dict__)

        return LanguageModelWrapper(self)
925

926
    def get_multimodal_embeddings(self, **kwargs):
        """Compute image embeddings from the multimodal keyword inputs.

        Precomputed ``image_embeds`` take priority and are returned unchanged.
        Otherwise ``pixel_values`` (or ``image_patches``, for models that use that
        name) are encoded with ``self.model.get_image_features``.

        Args:
            **kwargs: Multimodal inputs. ``num_image_patches`` is required when
                pixel values are present. ``token_type_ids`` is discarded here
                because it is only consumed by ``forward``. All remaining entries
                are forwarded to ``get_image_features``.

        Returns:
            ``image_embeds`` if given; ``None`` if there is no image input;
            otherwise the output of ``get_image_features``. When that output is
            a single tensor, it is split along dim 0 using ``num_image_patches``
            and each piece is flattened to 2D, keeping the last dimension.

        Raises:
            KeyError: If pixel values are given without ``num_image_patches``.
        """
        pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
        image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
        # Model might use `image_patches` instead of `pixel_values`
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)

        if image_embeds is not None:
            return image_embeds

        if pixel_values is None:
            return None

        # From here on `pixel_values` is guaranteed to be set; the former
        # `if pixel_values is not None:` guard was always true.
        num_image_patches = kwargs.pop("num_image_patches")
        kwargs.pop("token_type_ids", None)  # used only in `forward`
        vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

        if isinstance(vision_embeddings, torch.Tensor):
            if vision_embeddings.ndim == 2:
                # Add a leading dim of size 1 so the split along dim 0 below
                # still applies (presumably a single-patch output — confirm).
                vision_embeddings = vision_embeddings.unsqueeze(0)

            # Embeddings have to be 2D tensors of length `num_images`
            # but transformers returns concat tensors if each patch
            # is of different size. We split it back to make vLLM happy
            vision_embeddings = torch.split(
                vision_embeddings, num_image_patches.flatten().tolist()
            )
            vision_embeddings = [
                embed.flatten(start_dim=0, end_dim=-2)
                for embed in vision_embeddings
            ]

        return vision_embeddings

961
    # The MRO would otherwise pick `TransformersForCausalLM.get_input_embeddings`
    # (text-only lookup). Explicitly use the `SupportsMultiModal` version instead
    # (presumably the one that merges multimodal embeddings — confirm).
    get_input_embeddings = SupportsMultiModal.get_input_embeddings