transformers.py 36.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
18

19
from collections.abc import Iterable, Mapping
20
from contextlib import contextmanager
21
from pathlib import Path
22
from typing import Literal, Optional, Union
23

24
import regex as re
25
import torch
26
27
import transformers
from packaging.version import Version
28
from torch import nn
29
from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel
30
31
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

32
from vllm.attention import Attention, AttentionType
33
from vllm.compilation.decorators import support_torch_compile
34
35
36
37
38
39
40
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    ModelConfig,
    ParallelConfig,
    VllmConfig,
)
41
from vllm.config.multimodal import BaseDummyOptions
42
from vllm.config.utils import getattr_iter
43
from vllm.distributed import get_pp_group, get_tp_group
44
from vllm.distributed.utils import get_pp_indices
45
from vllm.logger import init_logger
46
from vllm.model_executor.layers.layernorm import RMSNorm
47
48
49
50
51
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
52
from vllm.model_executor.layers.logits_processor import LogitsProcessor
53
from vllm.model_executor.layers.quantization import QuantizationConfig
54
from vllm.model_executor.layers.vocab_parallel_embedding import (
55
56
57
    ParallelLMHead,
    VocabParallelEmbedding,
)
58
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
59
60
61
62
63
64
65
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalUUIDDict,
    PlaceholderRange,
)
66
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
67
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
68
from vllm.multimodal.profiling import BaseDummyInputsBuilder
69
70
from vllm.sequence import IntermediateTensors

71
72
73
74
75
76
77
78
79
80
81
82
83
84
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
)
85
86
87
88

# Module-level logger shared by the helpers and model classes in this file.
logger = init_logger(__name__)


89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build a hint telling the user where to request or add a missing feature.

    If `model` is an existing local path, only the documentation pointer is
    returned. Otherwise the hint also asks the user to open a discussion on the
    model's Hugging Face page (remote-code models) or an issue on the
    Transformers GitHub repository (built-in models).
    """
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    docs_hint = f"See {doc_url} for instructions on how to add support yourself."

    # Local checkpoints have nowhere upstream to file a request.
    if Path(model).exists():
        return docs_hint

    if trust_remote_code:
        where = f"a discussion at https://huggingface.co/{model}/discussions/new"
    else:
        where = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
    return f"Please open {where} to request support for this feature. {docs_hint}"


104
def vllm_flash_attention_forward(
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # Transformers kwargs
    scaling: Optional[float] = None,
    # vLLM kwargs
    attention_instances: Optional[dict[int, Attention]] = None,
    **kwargs,
):
    """Attention function that routes a Transformers layer to vLLM's `Attention`.

    Looks up the vLLM `Attention` instance for `module.layer_idx` and runs it
    on the flattened query/key/value tensors.

    Args:
        module: The Transformers attention module; only `layer_idx` is read.
        query, key, value: Projected attention inputs. They are transposed on
            dims 1 and 2 and then flattened to `(query.shape[-2], -1)`.
            Presumably laid out as (batch=1, heads, seq_len, head_dim);
            TODO confirm.
        attention_mask: Unused here.
        scaling: Softmax scale chosen by the Transformers model. When given,
            it overwrites the scale stored on the vLLM attention impl.
        attention_instances: Mapping from layer index to vLLM `Attention`.

    Returns:
        `(attention_output, None)`. No attention weights are returned.

    Raises:
        ValueError: If `attention_instances` is not provided.
    """
    if attention_instances is None:
        raise ValueError(
            "vllm_flash_attention_forward requires `attention_instances` "
            "(mapping of layer index to vLLM Attention)."
        )
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        # The Attention instance is created with a default scale; adopt the
        # model-specific one the first time Transformers provides it.
        self_attn.impl.scale = float(scaling)
    # After the transpose below this dimension indexes tokens, so each tensor
    # becomes (num_tokens, heads * head_dim).
    num_tokens = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1) for x in (query, key, value))
    return self_attn.forward(query, key, value), None
124
125
126
127
128


# Register the vLLM attention function with Transformers under the key "vllm",
# which is the `_attn_implementation` value `TransformersBase.__init__` sets on
# the text config before building the model.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


129
130
131
132
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Record, at debug level, that the module at `name` was swapped out."""
    template = "%s: %s -> %s"
    logger.debug(template, name, old_module, new_module)


133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Decide whether `@support_torch_compile` should be active for this model.

    Intended as the `enable_if` callback. Compilation is allowed unless the
    text config requests dynamic rope scaling (`rope_type == "dynamic"`),
    which is not compatible with torch.compile.
    """
    hf_text_config = vllm_config.model_config.hf_config.get_text_config()
    scaling_cfg: dict = getattr(hf_text_config, "rope_scaling", None) or {}
    return scaling_cfg.get("rope_type") != "dynamic"


150
# Tensor-parallel styles handled explicitly by `replace_linear_class`; other
# strings fall back to a replicated (unsharded) linear there.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
151
152


153
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear: `nn.Linear` to be replaced. Only its `in_features`,
            `out_features` and presence of `bias` are used; its weights are
            not copied.
        style: Tensor parallel style of the new linear, e.g. "colwise".
            Unrecognised style strings fall back to `ReplicatedLinear`.
        quant_config: Quantization config for the new linear.
        prefix: Qualified name of the layer, forwarded to the new linear.
    Returns:
        The new linear, constructed with `return_bias=False`.
    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    # Map each style to the vLLM class plus any extra constructor kwargs.
    # "_rep" variants keep the output gathered / the input unsharded.
    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )

192

193
194
195
196
197
198
199
200
201
202
203
204
205
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Replace a Transformers RMSNorm with vLLM's RMSNorm.

    This method assumes:
    - Weight is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon` (defaults to 1e-6).
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
    and Transformers doesn't appear to have the same concept.

    Args:
        rms_norm: The Transformers norm module being replaced.
        hidden_size: Hidden size for the new `RMSNorm`.
    Returns:
        A freshly constructed vLLM `RMSNorm` (weights are not copied). It uses
        the dtype of `rms_norm.weight` when present. It is weightless when
        `rms_norm` has no `weight`, regardless of `with_scale`.
    """
    kwargs = {
        "hidden_size": hidden_size,
        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
        "has_weight": getattr(rms_norm, "with_scale", True),
    }
    if (weight := getattr(rms_norm, "weight", None)) is not None:
        # If weight is a Parameter, get its data tensor
        weight = getattr(weight, "data", weight)
        kwargs["dtype"] = weight.dtype
    else:
        # No weight, fall back to weightless RMSNorm
        kwargs["has_weight"] = False
    return RMSNorm(**kwargs)


218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Adapted from `accelerate` (`init_on_device` with `include_buffers=False`)
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    This works by temporarily patching `nn.Module.register_parameter`.
    The original method is restored on exit, including when the body raises.

    Args:
        device (`str` or `torch.device`):
            Device to initialize all parameters on, e.g. ``"meta"``.
    """
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            registered = module._parameters[name]
            param_cls = type(registered)
            # Copy the instance dict rather than writing into the registered
            # parameter's own `__dict__`.
            kwargs = dict(registered.__dict__)
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(registered.to(device), **kwargs)

    # Buffers are intentionally left alone, so no tensor constructors need to
    # be patched (the upstream helper only patches them when including buffers).
    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter


class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for multimodal models run through the Transformers backend.

    Only the "image" modality is exposed.
    """

    def get_supported_mm_limits(self):
        # `None` is taken to mean no fixed per-prompt limit on images.
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        """Return the number of image tokens the HF processor reports for an
        image of `get_max_image_size()`, honouring `mm_processor_kwargs`."""
        width, height = self.get_max_image_size()
        processor = self.get_hf_processor()

        multimodal_config = self.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
        # The HF helper takes sizes as (height, width) pairs.
        mm_tokens = processor._get_num_multimodal_tokens(
            image_sizes=([height, width],), **mm_processor_kwargs
        )
        image_tokens = mm_tokens["num_image_tokens"][0]
        return image_tokens

    def get_max_image_size(self):
        return 10_000, 10_000  # hardcode for arbitrary very large size


292
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy text and image inputs for the Transformers backend.

    The base class comes from `vllm.multimodal.profiling`.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        # Gemma3 processors expect the begin-of-image token in the prompt.
        if "gemma3" in processor.__class__.__name__.lower():
            image_token = processor.boi_token
        else:
            image_token = getattr(processor, "image_token", "")
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized at the processing info's maximum image size.

        Any "image" entry in `mm_options` is forwarded as overrides.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_max_image_size()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }


class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Multimodal processor for models run through the Transformers backend.

    `apply` is overridden to run the HF processor directly. Image placeholder
    ranges are derived from the token type ids it returns.
    """

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.

        This processor returns no updates: its `apply` override computes the
        placeholder ranges itself from the HF processor's token type ids.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each HF output field is split per image.

        Note: removes `attention_mask` from `hf_inputs` in place.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        num_image_patches = hf_inputs.get("num_image_patches")
        # By default every field is flattened and split by patches per image.
        mm_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
            for key in hf_inputs
        }
        mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
            "image", num_image_patches
        )

        # Keep these as batched, as they always have batch size as first dim
        mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
        mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
        return mm_fields

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        In contrast to the base class, this method always adds
        `return_mm_token_type_ids` to the processor data
        """
        processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
        processor_data["return_mm_token_type_ids"] = True
        return processor_data, passthrough_data

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.

        Token-id prompts are decoded back to text first. One image placeholder
        range is produced per image, sized by the HF processor's per-image
        token counts.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # the prompt is the tokenized ids which is not supported
            # by the hf_processor, which is why we would need to decode the ids
            # into string
            prompt = hf_processor.decode(prompt)

        # Bypass cached processor and always apply to the full set of mm inputs
        # NOTE: we can't just set caching=False because base class method
        # transforms outputs to `MultiModalKwargs` which is not going to
        # work for Transformers. We have a lot of logic tied to
        # `mm_tokens_per_modality` below
        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        # For gemma3 we check `token_type_ids` as the key
        token_type_key = (
            "mm_token_type_ids"
            if "mm_token_type_ids" in processed_data
            else "token_type_ids"
        )
        mm_token_type_ids = processed_data.pop(token_type_key)

        # We can infer vLLM style placeholder from token type ids, if we split
        # it for each input `mm_data`.
        mm_positions = torch.where(mm_token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)
        multimodal_config = self.info.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
        # The HF helper takes sizes as (height, width) pairs.
        image_sizes = []
        for item_idx in range(len(images)):
            image_size = images.get_image_size(item_idx)
            image_sizes.append((image_size.height, image_size.width))

        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs
        )

        mm_placeholders = {}
        split_sizes = mm_tokens_per_modality["num_image_tokens"]
        if split_sizes:
            chunked_mm_positions = torch.split(mm_positions, split_sizes)
            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
            chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
            # Within each image's span, only the actual image tokens receive
            # embeddings; other special tokens in the span do not.
            ranges = [
                PlaceholderRange(
                    offset=positions[0].item(),
                    length=positions.shape[0],
                    is_embed=(chunk_tokens == hf_processor.image_token_id).bool(),
                )
                for positions, chunk_tokens in zip(
                    chunked_mm_positions, chunked_mm_tokens
                )
            ]
            mm_placeholders = {"image": ranges}

        processed_data["num_image_patches"] = torch.tensor(
            mm_tokens_per_modality["num_image_patches"]
        )
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


479
480
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Shared base for models executed through the Transformers backend.

    The Transformers model is built on the meta device. Layers not owned by
    this pipeline-parallel rank are replaced with `PPMissingLayer`. Linear
    layers, the input embedding and attention are swapped for vLLM
    equivalents. Any parameters still on the meta device are then allocated.
    """

    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Weights to skip in `self.load_weights`:
        # skip loading weights whose qualname starts with these prefixes.
        self.skip_prefixes: list[str] = []
        # Skip loading weights whose qualname contains these substrings.
        self.skip_substrs: list[str] = []
        # Ignore unexpected weights whose qualname starts with these prefixes.
        self.ignore_unexpected_prefixes: list[str] = []
        # Ignore unexpected weights whose qualname ends with these suffixes.
        self.ignore_unexpected_suffixes: list[str] = []

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError(
                    "Transformers backend does not support MXFP4 quantization yet."
                )
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                )
            )

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.text_config.hidden_size
        )

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Modules in the HF `_pp_plan` that this rank does not own are replaced
        with `PPMissingLayer`.
        - Modules before the single decoder `ModuleList` are kept on the first
          rank, and also on the last rank when word embeddings are tied.
        - Decoder layers outside this rank's `get_pp_indices` range are
          replaced.
        - Modules after the `ModuleList` are kept only on the last rank.

        Raises:
            ValueError: If the model has no PP plan, or its plan contains zero
                or more than one `ModuleList`.
        """
        if self.pp_group.world_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!"
            )
        if module_list_idx is None:
            raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            if self.pp_group.is_first_rank or (
                self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
            ):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if start_layer <= i < end_layer:
                continue
            layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1 :]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently, this replaces:

        - `nn.Linear` with vLLM's tensor parallel linear classes, choosing the
          style from the first matching pattern in the model's `tp_plan`
          (defaulting to "replicate")

        Raises:
            ValueError: If tensor parallelism is requested but the model has
                no `tp_plan`.
        """
        tp_plan = self.model.tp_plan

        if not tp_plan and self.tp_group.world_size > 1:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}"
            )

        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(
                        child_module, style, self.quant_config, prefix=qual_name
                    )
                # TODO(hmellor): Enable RMSNorm replacement once we have a way
                # to choose RMSNorm vs GemmaRMSNorm
                # elif child_module.__class__.__name__.endswith("RMSNorm"):
                #     new_module = replace_rms_norm_class(
                #         child_module, self.config.hidden_size)
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(
        self, attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created per layer index in this pipeline-parallel
        rank's `get_pp_indices` range, keyed by that index.
        """
        num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )

        # Read sliding-window layout from the same (text) config the layer
        # count comes from. For composite multimodal configs, `layer_types`
        # lives on the text sub-config. Fall back to the top-level config if
        # the text config does not define it.
        sliding_config = (
            self.text_config
            if hasattr(self.text_config, "layer_types")
            else self.config
        )
        layer_types = getattr(sliding_config, "layer_types", None)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if layer_types is not None and layer_types[i] == "sliding_attention":
                per_layer_sliding_window = sliding_config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type,
            )
        return attention_instances

    def init_parameters(self, module: nn.Module, dtype: Optional[torch.dtype] = None):
        """
        Allocate real storage for every parameter still on the `meta` device.

        If a `parameter` is on the `meta` device, then its parent
        `module` is one of the original modules created by:

        ```python
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        Such parameters are replaced by uninitialised tensors of the same
        shape on `self.device_config.device`. They use `dtype` when given,
        otherwise the model dtype.
        """

        def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        )
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's part of the Transformers model on a flat token batch.

        Non-first pipeline ranks read their input from
        `intermediate_tensors["hidden_states"]`. A leading batch dimension of
        size 1 is added before calling the model and removed from its output.
        Non-last ranks wrap the result in `IntermediateTensors`.
        """
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        # M-RoPE positions carry an extra leading dim, so the batch dim goes
        # second for them.
        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False,
        )[0][0, ...]  # we remove batch dimension for now

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights, honouring the skip/ignore lists above.

        A subclass may define `hf_to_vllm_mapper` to rename checkpoint keys.
        Without one, names are loaded as-is instead of raising AttributeError.
        Returns the set of loaded parameter names from `AutoWeightsLoader`.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        mapper = getattr(self, "hf_to_vllm_mapper", None)
        return loader.load_weights(weights, mapper=mapper)

    def check_version(self, min_version: str, feature: str):
        """Raise ImportError if installed `transformers` is older than `min_version`."""
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers backend requires transformers>={required} "
                f"for {feature}, but got {installed}"
            )
789

790

791
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    """Causal language model on top of the Transformers backend.

    On the last pipeline-parallel rank this adds a `ParallelLMHead` (tied to
    the input embeddings when the config asks for it) and a
    `LogitsProcessor`; every other rank gets a `PPMissingLayer` placeholder.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        text_config = self.text_config
        tied = text_config.tie_word_embeddings

        # With tied word embeddings, `TransformersBase.load_weights` must not
        # load a separate `lm_head` from the checkpoint.
        if tied:
            self.skip_prefixes.append("lm_head.")

        # Only the last pipeline stage owns the LM head.
        if not self.pp_group.is_last_rank:
            self.lm_head = PPMissingLayer()
            return

        vocab_size = text_config.vocab_size
        self.unpadded_vocab_size = vocab_size
        lm_head = ParallelLMHead(
            vocab_size,
            text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if tied:
            lm_head = lm_head.tie_weights(self.model.get_input_embeddings())
        self.lm_head = lm_head

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            vocab_size,
            getattr(text_config, "logit_scale", 1.0),
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed `input_ids` with the wrapped model's input embedding layer."""
        embed_tokens = self.model.get_input_embeddings()
        return embed_tokens(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Turn `hidden_states` into logits using `self.lm_head`.

        `self.logits_processor` is only created on the last pipeline rank.
        """
        return self.logits_processor(self.lm_head, hidden_states)

831
832
833
834

@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
835
836
    dummy_inputs=MultiModalDummyInputsBuilder,
)
837
@support_torch_compile(
838
    # set `positions` to last dim to support Qwen-mrope
839
840
841
842
843
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
844
    },
845
846
    enable_if=can_enable_torch_compile,
)
847
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
848
    # Class-level flag; it is not read in the visible code of this class.
    # NOTE(review): presumably opts into vLLM's per-field merging of batched
    # multimodal kwargs — confirm against `SupportsMultiModal`.
    merge_by_field_config = True

    # Backwards compatibility for previously released models: their state
    # dicts used a different layout and cannot be loaded with the `AutoModel`
    # mapping as is, so their old checkpoint prefixes are rewritten here.
    # NOTE(review): keep the entry order stable in case `WeightsMapper`
    # applies the prefix rules sequentially.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name of the language model.
            # Therefore each submodule is mapped explicitly, to avoid
            # clashing with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        }
    )
872

873
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the causal LM and remember the configured model dtype.

        Args:
            vllm_config: Full vLLM configuration; its `model_config.dtype` is
                stored on the instance as `self.dtype`.
            prefix: Module-name prefix forwarded to the parent constructor.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        model_config = vllm_config.model_config
        self.dtype = model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model through the parent `forward`.

        Extra keyword arguments are accepted for interface compatibility but
        are deliberately not forwarded: only the four named inputs reach
        `super().forward`.
        """
        return super().forward(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

891
892
893
    def get_language_model(self) -> torch.nn.Module:
        """Return `self.model`, the wrapped Transformers model used as the LM."""
        return self.model

894
    def get_multimodal_embeddings(self, **kwargs):
        """Compute image embeddings from the multimodal keyword arguments.

        Lookup order:
          1. `image_embeds`, if given, is returned unchanged.
          2. Otherwise `pixel_values` (or `image_patches` for models that use
             that name) is passed to `self.model.get_image_features` together
             with all remaining kwargs except `num_image_patches`.
          3. If neither pixel input is present, returns None.

        When `get_image_features` returns a single tensor, it is split along
        dim 0 using `num_image_patches`, and each chunk is flattened to 2D
        (all leading dims merged, last dim kept).

        Raises:
            KeyError: if pixel inputs are given without `num_image_patches`.
        """
        pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
        image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None)
        # Model might use `image_patches` instead of `pixel_values`
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)

        # Precomputed embeddings take precedence over raw pixel inputs.
        if image_embeds is not None:
            return image_embeds

        if pixel_values is None:
            return None

        # `pixel_values` is known to be non-None from here on, so the feature
        # extraction below no longer needs a redundant guard.
        num_image_patches = kwargs.pop("num_image_patches")
        vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

        if isinstance(vision_embeddings, torch.Tensor):
            if vision_embeddings.ndim == 2:
                vision_embeddings = vision_embeddings.unsqueeze(0)

            # Embeddings have to be 2D tensors of length `num_images`
            # but transformers returns concat tensors if each patch
            # is of different size. We split it back to make vLLM happy
            vision_embeddings = torch.split(
                vision_embeddings, num_image_patches.flatten().tolist()
            )
            vision_embeddings = [
                embed.flatten(start_dim=0, end_dim=-2)
                for embed in vision_embeddings
            ]

        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """
        Apply token embeddings to `input_ids`.

        If `multimodal_embeddings` is passed, merge them into the resulting
        text embeddings at the positions selected by the mask
        `is_multimodal`.

        In case the multi-modal token IDs exceed the vocabulary size of
        the language model, set `handle_oov_mm_token=True` so that the
        language model's `get_input_embeddings` method is not called on
        those tokens.

        Raises:
            ValueError: if non-empty `multimodal_embeddings` are given
                without `is_multimodal`.
        """
        from .utils import _merge_multimodal_embeddings

        # NOTE(review): the previous docstring said `handle_oov_mm_token=False`
        # avoids embedding out-of-vocabulary multimodal tokens, which
        # contradicts the flag's name. The flag is forwarded verbatim to
        # `_get_text_embeddings`; confirm its semantics there.
        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.model.get_input_embeddings(),
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229."
            )

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )