# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
18
from collections.abc import Iterable, Mapping
19
from contextlib import contextmanager
20
from pathlib import Path
21
from typing import Literal, Optional, Union
22

23
import regex as re
24
import torch
25
26
import transformers
from packaging.version import Version
27
from torch import nn
28
29
from transformers import (AutoModel, BatchFeature, PretrainedConfig,
                          PreTrainedModel)
30
31
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

32
from vllm.attention import Attention, AttentionType
33
from vllm.compilation.decorators import support_torch_compile
34
35
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
36
from vllm.config.multimodal import BaseDummyOptions
37
from vllm.config.utils import getattr_iter
38
39
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
40
from vllm.logger import init_logger
41
from vllm.model_executor.layers.layernorm import RMSNorm
42
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
43
                                               ReplicatedLinear,
44
45
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
46
from vllm.model_executor.layers.quantization import QuantizationConfig
47
48
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
49
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
50
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
51
52
                                    MultiModalInputs, MultiModalUUIDDict,
                                    PlaceholderRange)
53
54
55
56
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
57
58
from vllm.sequence import IntermediateTensors

59
60
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP, SupportsQuant)
61
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
62
63
                    flatten_bn, make_empty_intermediate_tensors_factory,
                    maybe_prefix)
64
65
66
67

logger = init_logger(__name__)


68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build a hint telling the user how to get a missing feature supported.

    Args:
        model: Model name on the Hugging Face Hub, or a local path.
        trust_remote_code: If true, the request should go to the model's Hub
            discussions; otherwise to the Transformers GitHub issue tracker.

    Returns:
        A sentence saying where to request support (omitted when `model` is
        an existing local path), followed by a pointer to vLLM's guide on
        writing custom models.
    """
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    tip = f"See {doc_url} for instructions on how to add support yourself."
    # A local checkout has no Hub page to request support on.
    if Path(model).exists():
        return tip
    if trust_remote_code:
        url = f"a discussion at https://huggingface.co/{model}/discussions/new"
    else:
        url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
    return f"Please open {url} to request support for this feature. {tip}"


83
84
85
86
87
88
89
90
def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attention_instances: Optional[dict[int, Attention]] = None,
        **kwargs):
    """Attention function registered with Transformers under "vllm".

    Routes the attention computation of a Transformers attention module to
    the vLLM `Attention` instance created for the same layer.

    Args:
        module: The calling Transformers attention module; its `layer_idx`
            selects the vLLM `Attention` instance.
        query, key, value: Projected tensors. Dims 1 and 2 are swapped and the
            result is flattened to `(query.shape[-2], -1)` — presumably the
            inputs are `(1, num_heads, num_tokens, head_dim)`; TODO confirm.
        attention_mask: Unused here.
        scaling: Softmax scale chosen by Transformers. When given, it replaces
            the scale stored on the vLLM attention implementation.
        attention_instances: Mapping from layer index to vLLM `Attention`.
        **kwargs: Remaining Transformers kwargs; ignored.

    Returns:
        `(attn_output, None)`. The second slot is where Transformers expects
        attention weights, which are not computed here.

    Raises:
        ValueError: If `attention_instances` was not supplied.
    """
    if attention_instances is None:
        raise ValueError(
            "vllm_flash_attention_forward requires `attention_instances`")
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        # Overrides the default scale the Attention instance was built with.
        self_attn.impl.scale = float(scaling)
    # Read before the transpose below moves this dimension.
    num_tokens = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1)
                         for x in (query, key, value))
    return self_attn.forward(query, key, value), None
102
103
104
105
106


# Register vLLM's attention function with Transformers under the name "vllm",
# so that models whose config sets `_attn_implementation = "vllm"` (as
# `TransformersBase.__init__` does) route attention through it.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


107
108
109
110
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Record at debug level that the module at `name` was swapped out."""
    fmt_args = (name, old_module, new_module)
    logger.debug("%s: %s -> %s", *fmt_args)


111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Decide whether the model may be compiled; pass this as the `enable_if`
    argument of `@support_torch_compile`.

    Returns `True` unless the text config requests dynamic rope scaling
    (`rope_scaling["rope_type"] == "dynamic"`).
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    # Dynamic rope scaling is not compatible with torch.compile
    return rope_scaling.get("rope_type") != "dynamic"


128
129
130
131
# Tensor parallel styles understood by `replace_linear_class`. Values of a
# Transformers model's `tp_plan` are passed through as this style.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep",
                "replicate"]


132
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Build the vLLM tensor parallel linear layer that stands in for `linear`.

    Args:
        linear: `nn.Linear` to be replaced. Only its `in_features`,
            `out_features` and whether it has a bias are used.
        style: Tensor parallel style of the new linear, e.g. "colwise".
            Unrecognised style strings fall back to `ReplicatedLinear`.
        quant_config: Quantization config for the new linear.
        prefix: Qualified name of the layer, forwarded to the new linear.

    Returns:
        The new linear, constructed with `return_bias=False`.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    if style == "colwise":
        linear_cls, extra_kwargs = ColumnParallelLinear, {}
    elif style == "colwise_rep":
        # Column parallel with `gather_output=True`.
        linear_cls = ColumnParallelLinear
        extra_kwargs = {"gather_output": True}
    elif style == "rowwise":
        linear_cls, extra_kwargs = RowParallelLinear, {}
    elif style == "rowwise_rep":
        # Row parallel with `input_is_parallel=False`.
        linear_cls = RowParallelLinear
        extra_kwargs = {"input_is_parallel": False}
    else:
        # "replicate", and any style string not listed above.
        linear_cls, extra_kwargs = ReplicatedLinear, {}

    return linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **extra_kwargs,
    )

176

177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Replace a Transformers RMSNorm with vLLM's RMSNorm.

    Assumptions about `rms_norm`:
    - Weight, if present, is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon` (default 1e-6).
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
    and Transformers doesn't appear to have the same concept.

    Returns a vLLM `RMSNorm` of size `hidden_size`. It is weightless when
    `rms_norm.weight` is missing or None; otherwise it uses that weight's
    dtype and honours `with_scale`.
    """
    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
    weight = getattr(rms_norm, "weight", None)
    if weight is None:
        # No weight, fall back to weightless RMSNorm
        return RMSNorm(hidden_size=hidden_size, eps=eps, has_weight=False)
    has_weight = getattr(rms_norm, "with_scale", True)
    # Take the dtype from the underlying tensor if weight is a Parameter
    dtype = getattr(weight, "data", weight).dtype
    return RMSNorm(hidden_size=hidden_size,
                   eps=eps,
                   has_weight=has_weight,
                   dtype=dtype)


202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
# Adapted from `accelerate` (`init_on_device` with `include_buffers=False`)
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    Only `nn.Module.register_parameter` is patched: every non-None parameter
    is moved to `device` right after it is registered. Buffers and other
    tensors keep the device they were created on. The original
    `register_parameter` is restored on exit, even if the body raises.

    Args:
        device (`torch.device` or `str`):
            Device to initialize all parameters on, e.g. `"meta"`.
    """
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        param_cls = type(registered)
        # Forward the parameter's instance attributes as constructor kwargs.
        # Copy the dict so the caller's parameter object is left untouched.
        kwargs = dict(registered.__dict__)
        kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = param_cls(registered.to(device), **kwargs)

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter


class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for image inputs handled by the Transformers backend."""

    def get_supported_mm_limits(self):
        # Only the image modality is supported; `None` presumably means no
        # fixed per-prompt limit.
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        # Independent of `seq_len` / `mm_counts`: always the largest image.
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        """Number of tokens the HF processor emits for the largest image."""
        max_width, max_height = self.get_max_image_size()
        hf_processor = self.get_hf_processor()
        mm_config = self.ctx.model_config.multimodal_config
        extra_kwargs = mm_config.mm_processor_kwargs or {}
        token_counts = hf_processor._get_num_multimodal_tokens(
            image_sizes=([max_height, max_width], ), **extra_kwargs)
        return token_counts["num_image_tokens"][0]

    def get_max_image_size(self):
        # Arbitrary, very large (width, height) used as the worst case.
        return 10_000, 10_000


class MultiModalDummyInputsBuilder(
        BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy prompts and images for the Transformers backend."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        image_count = mm_counts.get("image", 0)
        hf_processor = self.info.get_hf_processor()
        is_gemma3 = "gemma3" in hf_processor.__class__.__name__.lower()
        # Gemma3 processors use their begin-of-image token as placeholder.
        placeholder = (hf_processor.boi_token if is_gemma3 else getattr(
            hf_processor, "image_token", ""))
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` dummy images of the maximum size."""
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_max_image_size()
        overrides = None
        if mm_options:
            overrides = mm_options.get("image")
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count,
                                              overrides=overrides)
        return {"image": dummy_images}


class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Runs the HF processor on text and images together.

    Placeholder ranges are derived from the token type IDs returned by the
    HF processor, instead of from prompt updates.
    """

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.

        This processor overrides `apply` and computes placeholders there, so
        it returns no prompt updates.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs,
        num_image_patches: Optional[torch.Tensor] = None,
    ):
        """Describe how each processed field is split per image.

        Every field in `hf_inputs`, plus `image_embeds`, is flattened per image
        using `num_image_patches` as split sizes. `num_image_patches` itself is
        batched per image.

        NOTE: `attention_mask` is removed from `hf_inputs` in place.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        mm_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image",
                                                       num_image_patches)
            for key in hf_inputs
        }
        mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
            "image", num_image_patches)
        mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
        return mm_fields

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, torch.Tensor]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            A tuple `(prompt_ids, processed_data, mm_token_type_ids)`:
            - `prompt_ids`: token IDs of the single processed prompt.
            - `processed_data`: the remaining processor outputs, with
              `input_ids` and the token type IDs removed and passthrough
              data merged in.
            - `mm_token_type_ids`: token type IDs, where multi-modal tokens
              are marked with 1. The Gemma3 `token_type_ids` key is used when
              `mm_token_type_ids` is absent.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
        processor_data["return_mm_token_type_ids"] = True

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        processed_data.update(passthrough_data)

        # A single prompt was processed; unpack its batch of one.
        prompt_ids, = processed_data.pop("input_ids").tolist()
        if "mm_token_type_ids" in processed_data:
            mm_token_type_ids = processed_data.pop("mm_token_type_ids")
        else:
            mm_token_type_ids = processed_data.pop(
                "token_type_ids")  # for gemma3 only

        return prompt_ids, processed_data, mm_token_type_ids

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.

        A token-ID `prompt` is first decoded back to text. Each image then
        gets one `PlaceholderRange`, built by splitting the multi-modal token
        positions by the per-image token counts reported by the HF processor.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # the prompt is the tokenized ids which is not supported
            # by the hf_processor, which is why we would need to decode the ids
            # into string
            prompt = hf_processor.decode(prompt)

        (prompt_ids, processed_data,
         mm_token_type_ids) = self._apply_hf_processor_text_mm(
             prompt_text=prompt,
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

        # Infer placeholder positions from `mm_token_type_ids`.
        # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
        mm_positions = torch.where(mm_token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)

        multimodal_config = self.info.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
        image_sizes = []
        for item_idx in range(len(images)):
            image_size = images.get_image_size(item_idx)
            image_sizes.append((image_size.height, image_size.width))

        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs)

        mm_placeholders = {}
        split_sizes = mm_tokens_per_modality["num_image_tokens"]
        if split_sizes:
            # One chunk of positions / token IDs per image.
            chunked_mm_positions = torch.split(mm_positions, split_sizes)
            all_mm_tokens = torch.tensor(prompt_ids)[
                mm_token_type_ids[0].bool()]
            chunked_mm_tokens = torch.split(all_mm_tokens, split_sizes)
            ranges = [
                PlaceholderRange(
                    offset=positions[0].item(),
                    length=positions.shape[0],
                    # True only where the token is the image token itself.
                    is_embed=(chunk_tokens == hf_processor.image_token_id
                              ).bool())
                for positions, chunk_tokens in zip(chunked_mm_positions,
                                                   chunked_mm_tokens)
            ]
            mm_placeholders = {"image": ranges}

        num_image_patches = torch.tensor(
            mm_tokens_per_modality["num_image_patches"]
        ) if "num_image_patches" in mm_tokens_per_modality else None
        processed_data['num_image_patches'] = num_image_patches

        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                       num_image_patches),
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(mm_items,
                                        hf_processor_mm_kwargs,
                                        tokenization_kwargs,
                                        mm_uuids=mm_uuids)

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


468
469
470
471
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Wrap a Transformers model so that it runs inside vLLM.

    Construction proceeds as follows:
    - The Transformers model is built on the "meta" device.
    - Modules not owned by this pipeline parallel rank are replaced with
      `PPMissingLayer`.
    - `nn.Linear` layers are swapped for vLLM linear layers.
    - The input embedding becomes a `VocabParallelEmbedding`.
    - One vLLM `Attention` is created per local layer. Attention is routed to
      these instances through the "vllm" Transformers attention
      implementation.
    """
    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"
                         ]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: Optional[
            QuantizationConfig] = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # Weights to skip / tolerate in `self.load_weights`.
        # Skip loading weights whose qualname starts with these prefixes.
        self.skip_prefixes: list[str] = []
        # Skip loading weights whose qualname contains these substrings.
        self.skip_substrs: list[str] = []
        # Ignore unexpected weights whose qualname starts with these prefixes.
        self.ignore_unexpected_prefixes: list[str] = []
        # Ignore unexpected weights whose qualname ends with these suffixes.
        self.ignore_unexpected_suffixes: list[str] = []

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError("Transformers backend does not "
                                          "support MXFP4 quantization yet.")
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        # TODO: @raushan, use the public `model.set_attn_implementation()`
        # method once its checks are fixed in Transformers.
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            if embedding_dim is None:
                raise ValueError(
                    "Could not determine the embedding dimension: the text "
                    "config defines neither `embedding_size` nor "
                    "`hidden_size`.")
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                ))

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states"], self.text_config.hidden_size))

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        The Transformers model's `_pp_plan` must contain exactly one
        `nn.ModuleList`. Modules are then kept or dropped as follows:
        - Modules listed before it stay on the first rank, and also on the
          last rank when word embeddings are tied.
        - Only this rank's slice of the module list is kept.
        - Modules listed after it stay on the last rank.

        Everything that is dropped is replaced with `PPMissingLayer`.

        Raises:
            ValueError: If the model has no pipeline parallel plan, or the
                plan does not contain exactly one `nn.ModuleList`.
        """
        if self.pp_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(self.model_config.model,
                                          self.model_config.trust_remote_code)
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!")
        if module_list_idx is None:
            raise ValueError(
                f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list. With tied word embeddings the last rank
        # keeps them too, because its LM head is tied to the input embedding.
        keep_leading_modules = self.pp_group.is_first_rank or (
            self.text_config.tie_word_embeddings
            and self.pp_group.is_last_rank)
        if not keep_leading_modules:
            for name in pp_plan[:module_list_idx]:
                setattr(self.model, name, PPMissingLayer())

        # Module list: keep only this rank's slice of layers
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size)
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if not start_layer <= i < end_layer:
                layers[i] = PPMissingLayer()

        # Layers after module list: these only live on the last rank
        if not self.pp_group.is_last_rank:
            for name in pp_plan[module_list_idx + 1:]:
                setattr(self.model, name, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently this replaces every `nn.Linear` with one of vLLM's tensor
        parallel linear classes. The class is chosen by the first `tp_plan`
        pattern that matches the layer's qualified name. If no pattern
        matches, `ReplicatedLinear` is used.

        Replacing `*RMSNorm` with vLLM's `RMSNorm` is not enabled yet; see
        the TODO in the body.

        Raises:
            ValueError: If tensor parallelism is requested but the model has
                no `tp_plan`.
        """
        tp_plan = self.model.tp_plan or {}

        if not tp_plan and self.tp_size > 1:
            tip = get_feature_request_tip(self.model_config.model,
                                          self.model_config.trust_remote_code)
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}")

        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(child_module,
                                                      style,
                                                      self.quant_config,
                                                      prefix=qual_name)
                # TODO(hmellor): Enable RMSNorm replacement once we have a way
                # to choose RMSNorm vs GemmaRMSNorm
                # elif child_module.__class__.__name__.endswith("RMSNorm"):
                #     new_module = replace_rms_norm_class(
                #         child_module, self.config.hidden_size)
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(
        self,
        attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created per layer owned by this pipeline parallel
        rank, keyed by layer index. A layer whose `layer_types` entry is
        "sliding_attention" gets the text config's `sliding_window`.
        """
        num_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.text_config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)

        # Read per-layer attention types from the text config, like
        # `num_hidden_layers` above. For composite (multimodal) configs they
        # live on the text sub-config; for text-only models
        # `get_text_config()` typically returns the config itself.
        layer_types = getattr(self.text_config, "layer_types", None)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if (layer_types is not None
                    and layer_types[i] == "sliding_attention"):
                per_layer_sliding_window = self.text_config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type)
        return attention_instances

    def init_parameters(self,
                        module: nn.Module,
                        dtype: Optional[torch.dtype] = None):
        """
        Materialize parameters that are still on the `meta` device.

        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        In other words, that module was not replaced by a vLLM layer. Each such
        parameter is replaced by an uninitialized (`torch.empty_like`) tensor
        of `dtype` (default: the model dtype) on the configured device.
        Presumably `load_weights` fills it afterwards.
        """

        def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        ))
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped Transformers model.

        Inputs and hidden states:
        - Ranks other than the first pipeline parallel rank take their input
          from `intermediate_tensors["hidden_states"]`.
        - A leading batch dimension of size 1 is added to the inputs and
          stripped from the model output.

        Return value:
        - Non-last ranks return `IntermediateTensors`.
        - The last rank returns the hidden states.
        """
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint `weights` with `AutoWeightsLoader`.

        Honours the skip/ignore lists set up in `__init__` and renames
        checkpoint keys with `self.hf_to_vllm_mapper`. Returns whatever the
        loader reports as loaded.

        NOTE(review): `hf_to_vllm_mapper` is not defined on this class in the
        code shown; presumably the concrete model class provides it.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def check_version(self, min_version: str, feature: str):
        """Require the installed Transformers to be at least `min_version`.

        Raises:
            ImportError: If the installed version is older. The message names
                `feature` as the reason.
        """
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers backend requires transformers>={required} "
                f"for {feature}, but got {installed}")

773

774
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    """Causal language model built on the generic Transformers backend.

    On the last pipeline-parallel rank this adds a ``ParallelLMHead`` (tied
    to the input embeddings when the text config asks for it) and a
    ``LogitsProcessor``; other ranks get a ``PPMissingLayer`` placeholder.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        text_config = self.text_config
        tie_embeddings = text_config.tie_word_embeddings

        # With tied embeddings the checkpoint's `lm_head` weights are not
        # loaded; `TransformersBase.load_weights` skips this prefix.
        if tie_embeddings:
            self.skip_prefixes.append("lm_head.")

        if not get_pp_group().is_last_rank:
            # Only the last pipeline stage owns the output head.
            self.lm_head = PPMissingLayer()
            return

        vocab_size = text_config.vocab_size
        self.unpadded_vocab_size = vocab_size
        self.lm_head = ParallelLMHead(
            vocab_size,
            text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if tie_embeddings:
            self.lm_head = self.lm_head.tie_weights(
                self.model.get_input_embeddings())

        # `logit_scale` is optional in HF configs; 1.0 means no scaling.
        scale = getattr(text_config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                vocab_size, scale)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the wrapped model's embedding module."""
        embed_tokens = self.model.get_input_embeddings()
        return embed_tokens(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to vocabulary logits via ``lm_head``.

        Relies on ``self.logits_processor``, which ``__init__`` only creates
        on the last pipeline-parallel rank.
        """
        return self.logits_processor(self.lm_head, hidden_states)

814

815
816
817
818
819
820
821
822
823
824
825
def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    """Flatten until a list of tensors can be concatenated, then concat.

    `torch.concat` along dim 0 needs every tensor to share the same trailing
    shape (`shape[1:]`). While that is not the case, the list is passed
    through `flatten_bn` (presumably splitting each tensor along its first
    dimension -- see `flatten_bn`) and the check is repeated.

    Args:
        x: Tensors to concatenate.

    Returns:
        The concatenation of the (possibly flattened) tensors along dim 0.

    Raises:
        ValueError: if there are no tensors to concatenate, either because
            `x` is empty or because flattening produced an empty list.
    """
    tensors = x
    while True:
        # Without this guard an empty list never satisfies the shape check
        # and (assuming `flatten_bn([])` returns `[]`) flattening never
        # terminates -- the old recursive form died with RecursionError.
        if not tensors:
            raise ValueError(
                "flatten_and_concat received no tensors to concatenate")
        if len({t.shape[1:] for t in tensors}) == 1:
            return torch.concat(tensors)
        tensors = flatten_bn(tensors)


826
827
828
829
@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
    dummy_inputs=MultiModalDummyInputsBuilder)
830
@support_torch_compile(
831
    # set `positions` to last dim to support Qwen-mrope
832
833
834
835
836
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
837
838
    },
    enable_if=can_enable_torch_compile)
839
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
840
    merge_by_field_config = True
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
    # Backwards compatibility for previously released models. Their state
    # dicts used a different key layout that cannot be loaded with the
    # `AutoModel` module hierarchy as-is, so each checkpoint key prefix on
    # the left is renamed to the prefix on the right when weights are loaded
    # (this mapper is what `load_weights` passes as `mapper=`).
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        })

864
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the causal-LM wrapper, then record the model's dtype.

        ``self.dtype`` is taken from ``vllm_config.model_config.dtype``.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        model_config = vllm_config.model_config
        self.dtype = model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the causal-LM forward pass.

        Only the four standard arguments are forwarded; any extra keyword
        arguments are accepted and ignored. (Presumably multimodal inputs
        have already been merged into ``inputs_embeds`` by the caller --
        verify against the model runner.)
        """
        return super().forward(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

881
882
883
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped backbone (``self.model``) as the language model."""
        backbone = self.model
        return backbone

884
    def get_multimodal_embeddings(self, **kwargs):
        """Compute image embeddings from the multimodal keyword inputs.

        Precedence:
          1. ``image_embeds`` if given is returned unchanged.
          2. Otherwise ``pixel_values`` (or, if absent, ``image_patches``)
             is fed to ``self.model.get_image_features`` together with the
             remaining kwargs.
          3. If neither is present, ``None`` is returned.

        When the vision features come back as a single tensor, it is split
        per image using ``num_image_patches``, and every chunk is flattened
        to 2-D (all dimensions except the last merged into the first).

        Raises:
            KeyError: if pixel inputs are present but ``num_image_patches``
                is missing from kwargs.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # Some models call the raw image input `image_patches`; it is only
        # consumed when `pixel_values` is absent.
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)

        if image_embeds is not None:
            return image_embeds
        if pixel_values is None:
            return None

        patch_counts = kwargs.pop("num_image_patches")
        features = self.model.get_image_features(pixel_values, **kwargs)
        if not isinstance(features, torch.Tensor):
            return features

        if isinstance(patch_counts, list):
            patch_counts = torch.cat(patch_counts)
        # Treat a 2-D result as a single leading group so the split below
        # still operates along dim 0.
        if features.ndim == 2:
            features = features.unsqueeze(0)

        # transformers concatenates per-image features when images have
        # different patch counts; vLLM expects one 2-D tensor per image.
        split_sizes = patch_counts.flatten().tolist()
        return [
            chunk.flatten(start_dim=0, end_dim=-2)
            for chunk in torch.split(features, split_sizes)
        ]

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """
        Apply token embeddings to `input_ids`.

        If `multimodal_embeddings` is passed, scatter them into
        `input_ids` according to the mask `is_multimodal`.

        In case the multi-modal token IDs exceed the vocabulary size of
        the language model, you can set `handle_oov_mm_token=True`
        to avoid calling the language model's `get_input_embeddings` method
        on those tokens.

        Raises:
            ValueError: if non-empty `multimodal_embeddings` are given
                without the `is_multimodal` mask.
        """
        from .utils import _merge_multimodal_embeddings

        # Text embeddings are computed for every position first; multimodal
        # positions are overwritten by the merge below.
        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.model.get_input_embeddings(),
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229.")

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )