transformers.py 35.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
18
from collections.abc import Iterable, Mapping
19
from contextlib import contextmanager
20
from pathlib import Path
21
from typing import Literal, Optional, Union
22

23
import regex as re
24
25
import torch
from torch import nn
26
27
from transformers import (AutoModel, BatchFeature, PretrainedConfig,
                          PreTrainedModel)
28
29
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

30
from vllm.attention import Attention, AttentionType
31
from vllm.compilation.decorators import support_torch_compile
32
33
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
34
from vllm.config.utils import getattr_iter
35
36
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
37
38
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
39
                                               ReplicatedLinear,
40
41
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
45
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
46
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
47
48
                                    MultiModalInputs, MultiModalUUIDDict,
                                    PlaceholderRange)
49
50
51
52
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
53
54
from vllm.sequence import IntermediateTensors

55
56
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP, SupportsQuant)
57
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
58
59
                    flatten_bn, make_empty_intermediate_tensors_factory,
                    maybe_prefix)
60
61
62
63

# Module-level logger used by the helpers and model classes in this file.
logger = init_logger(__name__)


64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build a hint explaining how a missing feature could get supported.

    The hint always ends with a link to vLLM's custom-model documentation.
    When `model` is not an existing local path, it is preceded by a request
    to open a feature request: on the model's Hugging Face discussions page
    if the model uses remote code, otherwise on the Transformers issue
    tracker.
    """
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    tip = f"See {doc_url} for instructions on how to add support yourself."
    if Path(model).exists():
        # Local checkpoints have no upstream page to file a request against.
        return tip
    if trust_remote_code:
        url = f"a discussion at https://huggingface.co/{model}/discussions/new"
    else:
        url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
    return f"Please open {url} to request support for this feature. {tip}"


79
80
81
82
83
84
85
86
def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attention_instances: Optional[dict[int, Attention]] = None,
        **kwargs):
    """Route a Transformers attention call to the layer's vLLM `Attention`.

    This function is registered in Transformers' attention registry under the
    name "vllm". The vLLM `Attention` instance for the calling layer is looked
    up via `module.layer_idx` and run on flattened query/key/value tensors.

    Args:
        module: The Transformers attention module; must expose `layer_idx`.
        query: Query tensor. Assumed to be laid out as
            `(1, num_heads, num_tokens, head_dim)` (the batch dim of 1 is
            what `TransformersBase.forward` adds) — TODO confirm for all
            models.
        key: Key tensor, same layout assumption as `query`.
        value: Value tensor, same layout assumption as `query`.
        attention_mask: Unused here.
        scaling: Softmax scale chosen by Transformers. When given, it
            overwrites `impl.scale` of the layer's vLLM attention.
        attention_instances: Mapping from layer index to the vLLM
            `Attention` module of that layer.
        **kwargs: Any other Transformers kwargs; ignored.

    Returns:
        A tuple `(output, None)`; the second element presumably fills the
        attention-weights slot of the Transformers interface, which vLLM does
        not compute.
    """
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        # Mutates the layer's attention impl so vLLM uses the model's scale
        # instead of the default set at construction time.
        self_attn.impl.scale = float(scaling)
    num_tokens = query.shape[-2]
    # Swap the head and token dims, then merge heads into one feature dim.
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1)
                         for x in (query, key, value))
    return self_attn.forward(query, key, value), None
98
99
100
101
102


# Register `vllm_flash_attention_forward` in Transformers' global attention
# registry under the key "vllm". `TransformersBase.__init__` sets
# `_attn_implementation = "vllm"` on the text config, which presumably makes
# Transformers dispatch attention through this entry. This runs at import time.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


103
104
105
106
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug record noting that `old_module` at `name` was swapped."""
    record_args = (name, old_module, new_module)
    logger.debug("%s: %s -> %s", *record_args)


107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Decide whether `torch.compile` may be used for this model.

    Meant to be passed as the `enable_if` argument of
    `@support_torch_compile`. Compilation is allowed unless:

    - The model's text config uses dynamic rope scaling
      (`rope_scaling["rope_type"] == "dynamic"`).
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    # Dynamic rope scaling is not compatible with torch.compile
    return rope_scaling.get("rope_type") != "dynamic"


124
def replace_linear_class(
    linear: nn.Linear,
    style: Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep",
                   "replicate"],
    quant_config: Optional[QuantizationConfig],
    *,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced. Only its
            `in_features`, `out_features` and whether it has a bias are used.
        style (str): Tensor parallel style of the new linear. One of
            "colwise", "colwise_rep" (column parallel with gathered output),
            "rowwise", "rowwise_rep" (row parallel with non-parallel input)
            or "replicate". Any other string falls back to `ReplicatedLinear`.
        quant_config (QuantizationConfig): Quantization config for the new
            linear, or `None` when no quantization config is set.
        prefix (str): Qualified name of the layer, forwarded to the new
            linear's constructor.
    Returns:
        Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
            The new linear, constructed with `return_bias=False`.
    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {
            "gather_output": True
        }),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {
            "input_is_parallel": False
        }),
        "replicate": (ReplicatedLinear, {}),
    # Unknown styles are replicated rather than rejected so that every linear
    # still becomes a vLLM linear class.
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )

168

169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# Adapted from `accelerate`'s `init_on_device(..., include_buffers=False)`
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    It works by temporarily replacing `nn.Module.register_parameter`, so it
    affects every module constructed inside the `with` block. The original
    method is restored on exit, including when the block raises.

    Args:
        device (`str` or `torch.device`):
            Device to initialize all parameters on, e.g. `"meta"`.
    """
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        param_cls = type(registered)
        # Copy the parameter's extra attributes instead of writing into the
        # caller's parameter object, then rebuild it (same class, same
        # `requires_grad`) on the target device.
        kwargs = dict(registered.__dict__)
        kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = param_cls(registered.to(device), **kwargs)

    nn.Module.register_parameter = register_empty_parameter
    try:
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter


class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing metadata for multimodal models run through Transformers.

    Only the "image" modality is declared.
    """

    def get_supported_mm_limits(self):
        # The limit for images is `None` (presumably "no explicit limit" —
        # confirm against `BaseProcessingInfo`).
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        # `seq_len` and `mm_counts` do not influence the per-image maximum.
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        """Number of image tokens the HF processor reports for an image of
        the size returned by `get_max_image_size`."""
        max_width, max_height = self.get_max_image_size()
        hf_processor = self.get_hf_processor()

        mm_config = self.ctx.model_config.multimodal_config
        extra_kwargs = mm_config.mm_processor_kwargs or {}

        token_counts = hf_processor._get_num_multimodal_tokens(
            image_sizes=([max_height, max_width], ), **extra_kwargs)
        return token_counts["num_image_tokens"][0]

    def get_max_image_size(self):
        # Arbitrary, deliberately huge (width, height) used as a worst case.
        return 10_000, 10_000


class MultiModalDummyInputsBuilder(
        BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy text and images (e.g. for profiling) for image models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder string per requested image."""
        processor = self.info.get_hf_processor()
        is_gemma3 = "gemma3" in processor.__class__.__name__.lower()
        if is_gemma3:
            # Gemma3 processors are prompted with their begin-of-image token.
            placeholder = processor.boi_token
        else:
            placeholder = getattr(processor, "image_token", "")
        return placeholder * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_max_image_size`."""
        max_width, max_height = self.info.get_max_image_size()
        dummy_images = self._get_dummy_images(
            width=max_width,
            height=max_height,
            num_images=mm_counts.get("image", 0),
        )
        return {"image": dummy_images}


class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Multimodal processor that lets the HF processor do all prompt editing.

    `apply` runs the HF processor on text and images together, then derives
    placeholder ranges from the multimodal token type ids it returns instead
    of applying vLLM-side prompt updates.
    """

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.

        This processor defines no updates and returns `None`: its `apply`
        computes placeholder ranges directly from the HF processor output.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs,
        num_image_patches: Optional[torch.Tensor] = None,
    ):
        """Describe how each HF output field is split into per-image items.

        Every remaining field of `hf_inputs`, plus `image_embeds`, is
        configured as a flat image field split by `num_image_patches`;
        `num_image_patches` itself is configured as a batched image field.
        Note: `attention_mask` is removed from `hf_inputs` in place.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        mm_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image",
                                                       num_image_patches)
            for key in hf_inputs
        }
        mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
            "image", num_image_patches)
        mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
        return mm_fields

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, torch.Tensor]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            A tuple `(prompt_ids, processed_data, mm_token_type_ids)`:
            the token ids of the (single) processed prompt, the remaining HF
            outputs (with passthrough data merged in), and the multimodal
            token type ids that mark multimodal tokens in the prompt.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
        processor_data["return_mm_token_type_ids"] = True

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        processed_data.update(passthrough_data)

        # Exactly one prompt is processed, so unpack the single row.
        prompt_ids, = processed_data.pop("input_ids").tolist()
        if "mm_token_type_ids" in processed_data:
            mm_token_type_ids = processed_data.pop("mm_token_type_ids")
        else:
            # for gemma3 only: its processor reports `token_type_ids`
            mm_token_type_ids = processed_data.pop("token_type_ids")

        return prompt_ids, processed_data, mm_token_type_ids

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors. Placeholder ranges are
        derived from the multimodal token type ids, split per image using the
        token counts from `hf_processor._get_num_multimodal_tokens`.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # the prompt is the tokenized ids which is not supported
            # by the hf_processor, which is why we would need to decode the ids
            # into string
            prompt = hf_processor.decode(prompt)

        (prompt_ids, processed_data,
         mm_token_type_ids) = self._apply_hf_processor_text_mm(
             prompt_text=prompt,
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

        # Prompts and `mm_token_type_ids` are always bs=1 (tested on Llava),
        # so the column indices of the 1-entries are the positions of the
        # multimodal tokens within the prompt.
        mm_positions = torch.where(mm_token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)

        multimodal_config = self.info.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}

        image_sizes = []
        for item_idx in range(len(images)):
            image_size = images.get_image_size(item_idx)
            image_sizes.append((image_size.height, image_size.width))

        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs)

        mm_placeholders = {}
        split_sizes = mm_tokens_per_modality["num_image_tokens"]
        if split_sizes:
            chunked_mm_positions = torch.split(mm_positions, split_sizes)
            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
            chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
            ranges = [
                PlaceholderRange(
                    offset=positions[0].item(),
                    length=positions.shape[0],
                    # Mark which tokens in this range are the image token
                    # itself; other multimodal tokens are not marked.
                    is_embed=(item_tokens == hf_processor.image_token_id
                              ).bool())
                for positions, item_tokens in zip(chunked_mm_positions,
                                                  chunked_mm_tokens)
            ]
            mm_placeholders = {"image": ranges}

        if "num_image_patches" in mm_tokens_per_modality:
            num_image_patches = torch.tensor(
                mm_tokens_per_modality["num_image_patches"])
        else:
            num_image_patches = None
        processed_data["num_image_patches"] = num_image_patches

        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                       num_image_patches),
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(mm_items,
                                        hf_processor_mm_kwargs,
                                        tokenization_kwargs,
                                        mm_uuids=mm_uuids)

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


432
433
434
435
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Shared base for vLLM models backed by a Transformers implementation.

    The Transformers model is created on the "meta" device. Its pipeline and
    tensor parallel plans are then applied by swapping modules for vLLM
    equivalents. Any parameters still on "meta" are finally materialized on
    the configured device.
    """
    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"
                         ]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: Optional[
            QuantizationConfig] = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # Weights to skip in `self.load_weights`:
        # skip loading weights whose qualname starts with these prefixes,
        self.skip_prefixes: list[str] = []
        # skip loading weights whose qualname contains these substrings,
        self.skip_substrs: list[str] = []
        # ignore unexpected weights whose qualname starts with these prefixes,
        self.ignore_unexpected_prefixes: list[str] = []
        # ignore unexpected weights whose qualname ends with these suffixes.
        self.ignore_unexpected_suffixes: list[str] = []

        # Skip loading extra bias for GPTQ models.
        if self.quant_config and "gptq" in self.quant_config.get_name():
            self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        # TODO: @raushan, use the public `model.set_attn_implementation()`
        # method once its checks are fixed in Transformers.
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        self.pipeline_parallel()
        self.tensor_parallel()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                ))

        # Attention layers
        self.attention_instances = self.create_attention_instances()

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states"], self.text_config.hidden_size))

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Modules this rank does not own are replaced with `PPMissingLayer`.

        Raises:
            ValueError: If pipeline parallelism is requested but the model has
                no PP plan, or its plan has zero or several `ModuleList`s.
        """
        if self.pp_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(self.model_config.model,
                                          self.model_config.trust_remote_code)
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!")
        if module_list_idx is None:
            raise ValueError(
                f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list: kept on the first rank, and also on the
        # last rank when word embeddings are tied (the LM head reuses them).
        for name in pp_plan[:module_list_idx]:
            if self.pp_group.is_first_rank or (
                    self.text_config.tie_word_embeddings
                    and self.pp_group.is_last_rank):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list: keep only this rank's slice of layers.
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size)
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if not start_layer <= i < end_layer:
                layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1:]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def tensor_parallel(self):
        """
        Apply the model's tensor parallelization plan.
        Currently only supports linear layers.

        Raises:
            ValueError: If `tp_size > 1` but no `PreTrainedModel` inside
                `self.model` declares a `base_model_tp_plan`.
        """
        # Look for tp plans in all of the PreTrainedModels found in self.model
        is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
        supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
        pretrained_models = filter(is_pretrained_model, self.model.modules())
        models_with_tp_plan = filter(supports_tp_plan, pretrained_models)

        if not any(models_with_tp_plan) and self.tp_size > 1:
            tip = get_feature_request_tip(self.model_config.model,
                                          self.model_config.trust_remote_code)
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}")

        def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None):
            """Recursively replace `nn.Linear` children of `module`."""
            tp_plan = tp_plan or {}

            # If the current module is a PreTrainedModel, set the tp_plan for
            # all of its children
            if isinstance(module, PreTrainedModel):
                tp_plan = module.config.base_model_tp_plan or {}
                tp_plan = {
                    maybe_prefix(prefix, k): v
                    for k, v in tp_plan.items()
                }

            # Some weight loaders expect linear layers to inherit from vLLM's
            # LinearBase class, so we set a default style which causes any
            # unspecified linear layers to be replaced with ReplicatedLinear
            for child_name, child_module in module.named_children():
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    # Plan keys are used as regex patterns; first match wins.
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(child_module,
                                                      style,
                                                      self.quant_config,
                                                      prefix=qual_name)
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)
                else:
                    _tensor_parallel(child_module,
                                     prefix=qual_name,
                                     tp_plan=tp_plan)

        _tensor_parallel(self.model, prefix="model")

    def create_attention_instances(
        self,
        attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created per layer owned by this pipeline rank, keyed
        by layer index. Layers marked "sliding_attention" in the text config's
        `layer_types` get the text config's `sliding_window`.
        """
        num_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.text_config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)

        # Read from the text config: for multimodal models the per-layer
        # attention types live on the text sub-config, not the top-level one.
        layer_types = getattr(self.text_config, "layer_types", None)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if (layer_types is not None
                    and layer_types[i] == "sliding_attention"):
                per_layer_sliding_window = self.text_config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type)
        return attention_instances

    def init_parameters(self,
                        module: nn.Module,
                        dtype: Optional[torch.dtype] = None):
        """
        Materialize, recursively, every parameter still on the `meta` device.

        If a `parameter` is on the `meta` device, then its parent `module` is
        an original module created by:

        ```python
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        and was not replaced by a vLLM module. Such parameters are replaced by
        uninitialized (`torch.empty_like`) tensors of `dtype` (default: the
        model dtype) on the configured device; their values are expected to be
        filled in later, presumably by weight loading.
        """
        for name, param in module.named_parameters(recurse=False):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data,
                                     dtype=dtype or self.model_config.dtype,
                                     device=self.device_config.device))
                setattr(module, name, new_param)
        for child in module.children():
            self.init_parameters(child, dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        Run the Transformers model on one flattened batch of tokens.

        Non-first PP ranks take their input from
        `intermediate_tensors["hidden_states"]`. A leading batch dim of 1 is
        added before calling the model and removed from its output.
        Non-last PP ranks return `IntermediateTensors`; the last rank returns
        the hidden states.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, honoring the skip/ignore lists above."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        # NOTE(review): `self.hf_to_vllm_mapper` is not defined in this class
        # as visible here; presumably it is a class attribute provided
        # elsewhere (e.g. on subclasses). Confirm before relying on it.
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
715
716


717
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    """Causal LM on top of `TransformersBase`.

    The last pipeline-parallel rank owns a vocab-parallel LM head and a
    logits processor; other ranks get a `PPMissingLayer` placeholder head.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        tied = self.text_config.tie_word_embeddings
        # Tell `TransformersBase.load_weights` to skip
        # `lm_head` if the model has tied word embeddings
        if tied:
            self.skip_prefixes.append("lm_head.")

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
            return

        vocab_size = self.text_config.vocab_size
        self.unpadded_vocab_size = vocab_size
        self.lm_head = ParallelLMHead(
            vocab_size,
            self.text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if tied:
            # Share the weights of the (already replaced) input embedding.
            self.lm_head = self.lm_head.tie_weights(
                self.model.get_input_embeddings())

        logit_scale = getattr(self.text_config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                vocab_size, logit_scale)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed `input_ids` with the model's input embedding layer."""
        embed_tokens = self.model.get_input_embeddings()
        return embed_tokens(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project `hidden_states` to logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

757

758
759
760
761
762
763
764
765
766
767
768
def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    """Flatten a list of tensors until it can be concatenated, then concat.

    The tensors can be concatenated along dim 0 once they all share the same
    trailing shape (``shape[1:]``). Until then, the list is passed through
    `flatten_bn` and the check is retried on the result.

    Args:
        x: Non-empty list of tensors.

    Returns:
        The tensors concatenated along dim 0.

    Raises:
        ValueError: If `x` is empty.
    """
    if not x:
        # An empty list never has exactly one trailing shape, so it would be
        # handed to `flatten_bn` and recursed on without end; fail fast.
        raise ValueError("Cannot concatenate an empty list of tensors")

    trailing_shapes = {t.shape[1:] for t in x}
    if len(trailing_shapes) == 1:
        return torch.concat(x)
    return flatten_and_concat(flatten_bn(x))


769
770
771
772
@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
    dummy_inputs=MultiModalDummyInputsBuilder)
773
@support_torch_compile(
774
    # set `positions` to last dim to support Qwen-mrope
775
776
777
778
779
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
780
781
    },
    enable_if=can_enable_torch_compile)
782
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
783
    # NOTE(review): class-level flag that is not read in this chunk;
    # presumably consumed by vLLM's multi-modal input batching (inherited
    # `SupportsMultiModal` interface) — confirm its exact semantics.
    merge_by_field_config = True
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
    # Backwards compatibility for previously released models. Their state
    # dicts used a different layout and cannot be loaded with the
    # `AutoModel` mapping as is, so checkpoint prefixes are rewritten to the
    # current layout (submodules under `model.`, the head at `lm_head`).
    # NOTE(review): if `WeightsMapper` applies these prefixes in insertion
    # order, the ordering below matters — confirm before reordering.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        })

807
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the causal-LM wrapper and record the configured dtype.

        Args:
            vllm_config: Full vLLM configuration; `model_config.dtype` is
                stored on `self.dtype`.
            prefix: Module-name prefix forwarded to the parent constructor.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        model_config = vllm_config.model_config
        self.dtype = model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the causal-LM forward pass.

        Extra keyword arguments are accepted for interface compatibility
        but are not forwarded. NOTE(review): multi-modal inputs are
        presumably already merged into `inputs_embeds` by the caller —
        confirm with the model runner.
        """
        return super().forward(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

824
825
826
    def get_language_model(self) -> torch.nn.Module:
        """Return the module acting as the language model.

        For the Transformers backend this is the whole wrapped `self.model`.
        """
        language_model: torch.nn.Module = self.model
        return language_model

827
    def get_multimodal_embeddings(self, **kwargs):
        """Compute image embeddings from the multi-modal keyword inputs.

        Precomputed `image_embeds` are returned unchanged. Otherwise the
        pixel inputs (`pixel_values`, or `image_patches` for models that use
        that name) are encoded with `self.model.get_image_features`.

        Args:
            **kwargs: Multi-modal inputs. `pixel_values`, `image_patches`,
                `image_embeds` and `num_image_patches` are consumed here.
                The remaining entries are forwarded to `get_image_features`.

        Returns:
            `image_embeds` if given, or `None` if there are no pixel inputs.
            Otherwise the vision embeddings. When `get_image_features`
            returns a single tensor, it is split along dim 0 by
            `num_image_patches` and each piece is flattened to 2D.

        Raises:
            KeyError: If pixel inputs are given without `num_image_patches`.
        """
        pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
        image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None)
        # Model might use `image_patches` instead of `pixel_values`
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)

        if image_embeds is not None:
            return image_embeds
        if pixel_values is None:
            return None

        # Popped before the call below so it is not forwarded as a kwarg.
        num_image_patches = kwargs.pop("num_image_patches")
        vision_embeddings = self.model.get_image_features(
            pixel_values, **kwargs)

        if not isinstance(vision_embeddings, torch.Tensor):
            # Non-tensor outputs are returned unchanged.
            return vision_embeddings

        if isinstance(num_image_patches, list):
            num_image_patches = torch.cat(num_image_patches)

        if vision_embeddings.ndim == 2:
            vision_embeddings = vision_embeddings.unsqueeze(0)

        # Embeddings have to be 2D tensors of length `num_images`
        # but transformers returns concat tensors if each patch
        # is of different size. We split it back to make vLLM happy
        split_embeddings = torch.split(
            vision_embeddings, num_image_patches.flatten().tolist())
        return [
            embed.flatten(start_dim=0, end_dim=-2)
            for embed in split_embeddings
        ]

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """
        Apply token embeddings to `input_ids`.

        If `multimodal_embeddings` is passed, merge them into the resulting
        text embeddings at the positions selected by the mask
        `is_multimodal`.

        In case the multi-modal token IDs exceed the vocabulary size of
        the language model, you can set `handle_oov_mm_token=True`
        to avoid calling the language model's `get_input_embeddings` method
        on those tokens.

        Raises:
            ValueError: If non-empty `multimodal_embeddings` are given
                without `is_multimodal`.
        """
        # Function-scope import. NOTE(review): presumably to avoid an import
        # cycle with `.utils` — confirm.
        from .utils import _merge_multimodal_embeddings

        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.model.get_input_embeddings(),
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229.")

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )