transformers.py 37.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
18
from collections.abc import Iterable, Mapping
19
from contextlib import contextmanager
20
from pathlib import Path
21
from typing import Literal, Optional, Union
22

23
import regex as re
24
25
import torch
from torch import nn
26
27
from transformers import (AutoModel, BatchFeature, PretrainedConfig,
                          PreTrainedModel)
28
29
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

30
from vllm.attention import Attention, AttentionType
31
from vllm.compilation.decorators import support_torch_compile
32
33
34
35
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
36
37
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
38
                                               ReplicatedLinear,
39
40
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.quantization import QuantizationConfig
42
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
44
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
45
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
46
47
                                    MultiModalInputs, MultiModalUUIDDict,
                                    PlaceholderRange)
48
49
50
51
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
52
from vllm.sequence import IntermediateTensors
53
from vllm.utils import is_list_of
54

55
56
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                         SupportsQuant)
57
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
58
59
                    flatten_bn, make_empty_intermediate_tensors_factory,
                    maybe_prefix)
60
61
62
63

logger = init_logger(__name__)


64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build a hint telling the user how to get a missing feature supported.

    Args:
        model: Model name or local path, as passed to vLLM.
        trust_remote_code: When True, the user is pointed at the model's
            Hugging Face discussions page; otherwise at the Transformers
            GitHub issue tracker.

    Returns:
        A pointer to the vLLM docs on writing custom models. Unless `model`
        is an existing local path, it is preceded by a request to open a
        discussion/issue.
    """
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"  # noqa: E501
    tip = f"See {doc_url} for instructions on how to add support yourself."
    # A local checkpoint has no public page to file a request against.
    if Path(model).exists():
        return tip
    if trust_remote_code:
        url = f"a discussion at https://huggingface.co/{model}/discussions/new"
    else:
        url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"  # noqa: E501
    return f"Please open {url} to request support for this feature. {tip}"


79
80
81
82
83
84
85
86
def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attention_instances: Optional[dict[int, Attention]] = None,
        **kwargs):
    """Attention function registered with Transformers under ``"vllm"``.

    Routes a Transformers attention call to the vLLM ``Attention`` instance
    stored for ``module.layer_idx``. ``attention_mask`` and any extra
    ``kwargs`` are accepted for signature compatibility but are not used.

    Args:
        module: The Transformers attention module; only ``layer_idx`` is read.
        query, key, value: Projected tensors as passed by Transformers.
            NOTE(review): assumed to be (1, num_heads, num_tokens, head_dim);
            the reshape below folds everything after the token axis together,
            so it is only correct for a batch of 1 (the batch dimension that
            ``TransformersBase.forward`` adds).
        attention_mask: Unused.
        scaling: Softmax scale chosen by Transformers. When given, it replaces
            the default scale on the vLLM attention backend.
        attention_instances: Mapping from layer index to vLLM ``Attention``.

    Returns:
        ``(output, None)``. The second element (the attention-weights slot of
        the Transformers interface) is always ``None``.
    """
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        # NOTE: mutates the backend impl that every call for this layer uses.
        self_attn.impl.scale = float(scaling)
    num_tokens = query.shape[-2]
    # (1, heads, tokens, head_dim) -> (1, tokens, heads, head_dim)
    # -> (tokens, heads * head_dim), the flat layout passed to vLLM.
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1)
                         for x in (query, key, value))
    return self_attn.forward(query, key, value), None


# Register under the name `TransformersBase.__init__` assigns to
# `text_config._attn_implementation`, so the wrapped model uses this function.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


103
104
105
106
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Log, at debug level, that the module at `name` was swapped out."""
    replacement = (name, old_module, new_module)
    logger.debug("%s: %s -> %s", *replacement)


107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Decide whether `@support_torch_compile` should compile the model.

    Intended as the decorator's `enable_if` argument. Returns `False` when
    the text config requests dynamic rope scaling
    (`rope_scaling["rope_type"] == "dynamic"`), which is not compatible with
    torch.compile, and `True` otherwise.
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    # A missing or explicitly-None `rope_scaling` means "no rope scaling".
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    uses_dynamic_rope = rope_scaling.get("rope_type") == "dynamic"
    return not uses_dynamic_rope


124
def replace_linear_class(
    linear: nn.Linear,
    style: Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep",
                   "replicate"],
    quant_config: Optional[QuantizationConfig],
    *,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced. Only its input/output
            sizes and whether it has a bias are used; its weights are not
            copied.
        style (str): Tensor parallel style of the new linear. Recognised
            styles:
            - "colwise": `ColumnParallelLinear`
            - "colwise_rep": `ColumnParallelLinear` with `gather_output=True`
            - "rowwise": `RowParallelLinear`
            - "rowwise_rep": `RowParallelLinear` with
              `input_is_parallel=False`
            - "replicate": `ReplicatedLinear`
            Any other string silently falls back to `ReplicatedLinear`.
        quant_config (QuantizationConfig, optional): Quantization config for
            the new linear.
        prefix (str): Qualified name of the layer, forwarded to the new
            linear.

    Returns:
        Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
            The new linear, built with `return_bias=False`.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {
            "gather_output": True
        }),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {
            "input_is_parallel": False
        }),
        "replicate": (ReplicatedLinear, {}),
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )

168

169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Adapted from `accelerate` (`init_on_device`)
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    Unlike the `accelerate` original, tensor constructors are never patched
    (that path only served to place buffers on `device` too). Only
    `nn.Module.register_parameter` is patched, and it is restored on exit,
    even if the body raises.

    Args:
        device (`torch.device` or `str`):
            Device to initialize all parameters on, e.g. `"meta"`.
    """
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            registered = module._parameters[name]
            param_cls = type(registered)
            # Copy the instance attributes instead of aliasing `__dict__`, so
            # the caller's parameter object is not mutated.
            kwargs = dict(registered.__dict__)
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(registered.to(device),
                                                 **kwargs)

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter


class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for multimodal models run through the Transformers
    backend. Only the "image" modality is declared."""

    def get_hf_config(self):
        model_config = self.ctx.model_config
        return model_config.hf_config

    def get_supported_mm_limits(self):
        # No explicit per-prompt image cap is declared here (presumably
        # interpreted as unlimited by the caller).
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens}

    def get_max_image_tokens(self) -> int:
        """Number of tokens the HF processor reports for a single image of
        the size returned by `get_max_image_size`."""
        width, height = self.get_max_image_size()
        hf_processor = self.get_hf_processor()

        mm_config = self.ctx.model_config.multimodal_config
        processor_kwargs = mm_config.mm_processor_kwargs or {}
        # Image sizes are given to the processor as (height, width).
        token_counts = hf_processor._get_num_multimodal_tokens(
            image_sizes=([height, width], ), **processor_kwargs)
        return token_counts["num_image_tokens"][0]

    def get_max_image_size(self):
        # Arbitrary, very large size used as an upper bound.
        max_width = max_height = 10_000
        return max_width, max_height


class MultiModalDummyInputsBuilder(
        BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy image prompts (placeholder text plus images) for the
    Transformers backend."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image.

        Gemma3 processors use `boi_token` (presumably begin-of-image); other
        processors use `image_token`, or an empty string if they have none.
        """
        image_count = mm_counts.get("image", 0)

        hf_processor = self.info.get_hf_processor()
        processor_name = hf_processor.__class__.__name__.lower()
        if "gemma3" not in processor_name:
            placeholder = getattr(hf_processor, "image_token", "")
        else:
            placeholder = hf_processor.boi_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` dummy images at the maximum image size
        reported by `self.info`."""
        image_count = mm_counts.get("image", 0)

        max_width, max_height = self.info.get_max_image_size()
        dummy_images = self._get_dummy_images(width=max_width,
                                              height=max_height,
                                              num_images=image_count)
        return {"image": dummy_images}


class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Multimodal processor for models run through the Transformers backend.

    The HF processor is relied on to expand image placeholders itself and to
    report which prompt positions are multimodal via `mm_token_type_ids`
    (`token_type_ids` for Gemma3). Placeholder ranges are derived from that
    output in `apply`.
    """

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.

        This processor overrides `apply` and computes placeholder ranges from
        the HF processor output directly, so no prompt updates are returned.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs,
        num_image_patches: Optional[torch.Tensor] = None,
    ):
        """Describe how each HF output field maps onto image items.

        Every remaining field of `hf_inputs`, plus `image_embeds`, is split
        per image with `num_image_patches` as the flat sizes, while
        `num_image_patches` itself is batched per image.

        NOTE: pops `attention_mask` from `hf_inputs` in place.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        mm_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image",
                                                       num_image_patches)
            for key in hf_inputs
        }
        mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
            "image", num_image_patches)
        mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
        return mm_fields

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, torch.Tensor]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            A tuple of:
            - the prompt token IDs (from the single-row `input_ids`),
            - the remaining processed data, including passthrough data,
            - the multimodal token type IDs (`mm_token_type_ids`, or
              `token_type_ids` for Gemma3), removed from the processed data.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
        processor_data["return_mm_token_type_ids"] = True

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        processed_data.update(passthrough_data)

        # The prompt is processed as a batch of one row.
        prompt_ids, = processed_data.pop("input_ids").tolist()
        if "mm_token_type_ids" in processed_data:
            mm_token_type_ids = processed_data.pop("mm_token_type_ids")
        else:
            # for gemma3 only
            mm_token_type_ids = processed_data.pop("token_type_ids")

        return prompt_ids, processed_data, mm_token_type_ids

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.

        Token-ID prompts are decoded back to text first, because the HF
        processor is called with text. Image placeholder ranges come from the
        positions flagged in the multimodal token type IDs. They are split
        per image using the token counts reported by the processor's
        `_get_num_multimodal_tokens`.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # the prompt is the tokenized ids which is not supported
            # by the hf_processor, which is why we would need to decode the ids
            # into string
            prompt = hf_processor.decode(prompt)

        (prompt_ids, processed_data,
         mm_token_type_ids) = self._apply_hf_processor_text_mm(
             prompt_text=prompt,
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

        # Positions of multimodal tokens in the prompt. Prompts and
        # `mm_token_type_ids` are always bs=1, so column indices are the
        # token positions. Tested on Llava.
        mm_positions = torch.where(mm_token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)

        multimodal_config = self.info.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}

        image_sizes = []
        for item_idx in range(len(images)):
            image_size = images.get_image_size(item_idx)
            image_sizes.append((image_size.height, image_size.width))

        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs)

        mm_placeholders = {}
        split_sizes = mm_tokens_per_modality["num_image_tokens"]
        if split_sizes:
            chunked_mm_positions = torch.split(mm_positions, split_sizes)
            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
            chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
            ranges = [
                PlaceholderRange(
                    offset=positions[0].item(),
                    length=positions.shape[0],
                    # Only the image token itself receives an embedding;
                    # other multimodal tokens in the range do not.
                    is_embed=(item_tokens == hf_processor.image_token_id
                              ).bool())
                for positions, item_tokens in zip(chunked_mm_positions,
                                                  chunked_mm_tokens)
            ]
            mm_placeholders = {"image": ranges}

        num_image_patches = torch.tensor(
            mm_tokens_per_modality["num_image_patches"]
        ) if "num_image_patches" in mm_tokens_per_modality else None
        processed_data['num_image_patches'] = num_image_patches
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                       num_image_patches),
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(mm_items,
                                        hf_processor_mm_kwargs,
                                        tokenization_kwargs,
                                        mm_uuids=mm_uuids)

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


435
436
437
438
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Base wrapper that runs a Transformers model with vLLM components.

    The Transformers model is built with its parameters on the ``meta``
    device. It is then pruned for pipeline parallelism, and its
    ``nn.Linear`` layers and input embeddings are swapped for vLLM's
    parallel equivalents. Parameters still on ``meta`` are then
    materialized on the target device. Attention is routed to vLLM through
    the ``"vllm"`` attention implementation. Subclasses must define
    ``hf_to_vllm_mapper``, which ``load_weights`` uses.
    """
    embedding_padding_modules = ["lm_head"]
    # TODO transformers will have a util to get it
    embedding_modules = ["embed_tokens"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: Optional[
            QuantizationConfig] = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # Weights to skip in `self.load_weights`
        self.skip_prefixes: list[str] = []
        self.skip_substrs: list[str] = []

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        # TODO: @raushan, use the public `model.set_attn_implementation()`
        # method once its checks are fixed in Transformers.
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        self.pipeline_parallel()
        self.tensor_parallel()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    self.text_config.hidden_size,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                ))

        # Attention layers
        self.attention_instances = self.create_attention_instances()

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states"], self.text_config.hidden_size))

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Modules of `self.model` that this pipeline rank does not run are
        replaced with `PPMissingLayer`. The plan (`self.model._pp_plan`) must
        contain exactly one `nn.ModuleList` (presumably the decoder layers).
        - Modules listed before it are kept on the first rank, and also on
          the last rank when word embeddings are tied.
        - Only this rank's slice of the list's layers is kept.
        - Modules listed after it are kept on the last rank only.

        Raises:
            ValueError: If the model has no pipeline parallel plan, or the
                plan contains zero or several `nn.ModuleList`s.
        """
        if self.pp_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(self.model_config.model,
                                          self.model_config.trust_remote_code)
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!")
        if module_list_idx is None:
            raise ValueError(
                f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            if self.pp_group.is_first_rank or (
                    self.text_config.tie_word_embeddings
                    and self.pp_group.is_last_rank):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size)
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if start_layer <= i < end_layer:
                continue
            layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1:]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def tensor_parallel(self):
        """
        Apply the model's tensor parallelization plan.
        Currently only supports linear layers.

        Every `nn.Linear` in `self.model` is replaced via
        `replace_linear_class`. Its style comes from the plan of the nearest
        enclosing `PreTrainedModel` (`config.base_model_tp_plan`), using the
        first plan key that `re.match`es the layer's qualified name, or
        "replicate" if none matches.

        Raises:
            ValueError: If `tp_size > 1` but no `PreTrainedModel` inside
                `self.model` has a tensor parallel plan.
        """
        # Look for tp plans in all of the PreTrainedModels found in
        # self.model. This is only needed to reject tp_size > 1, so the scan
        # is skipped otherwise.
        if self.tp_size > 1 and not any(
                isinstance(m, PreTrainedModel)
                and m.config.base_model_tp_plan is not None
                for m in self.model.modules()):
            tip = get_feature_request_tip(self.model_config.model,
                                          self.model_config.trust_remote_code)
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}")

        def _tensor_parallel(module: nn.Module,
                             prefix: str = "",
                             tp_plan=None):
            tp_plan = tp_plan or {}

            # If the current module is a PreTrainedModel, set the tp_plan for
            # all of its children
            if isinstance(module, PreTrainedModel):
                tp_plan = module.config.base_model_tp_plan or {}
                tp_plan = {
                    maybe_prefix(prefix, k): v
                    for k, v in tp_plan.items()
                }

            # Some weight loaders expect linear layers to inherit from vLLM's
            # LinearBase class, so we set a default style which causes any
            # unspecified linear layers to be replaced with ReplicatedLinear
            for child_name, child_module in module.named_children():
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    # Plan keys are regexes matched against the start of the
                    # qualified name; the first match wins.
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(child_module,
                                                      style,
                                                      self.quant_config,
                                                      prefix=qual_name)
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)
                else:
                    _tensor_parallel(child_module,
                                     prefix=qual_name,
                                     tp_plan=tp_plan)

        _tensor_parallel(self.model)

    def create_attention_instances(
        self,
        attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created per layer owned by this pipeline rank, keyed
        by global layer index. Layers marked "sliding_attention" in the
        config's `layer_types` get that config's `sliding_window`.
        """
        num_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.text_config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)

        # `layer_types` describes interleaved sliding window attention. Look
        # on the top-level config first. Composite (e.g. multimodal) configs
        # keep it on the text sub-config, so fall back to `text_config`.
        layer_types_config = self.config
        if getattr(layer_types_config, "layer_types", None) is None:
            layer_types_config = self.text_config
        layer_types = getattr(layer_types_config, "layer_types", None)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if (layer_types is not None
                    and layer_types[i] == "sliding_attention"):
                per_layer_sliding_window = layer_types_config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type)
        return attention_instances

    def init_parameters(self, module: nn.Module):
        """
        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        Recursively over `module` and its submodules, each such parameter is
        replaced with an uninitialized (`torch.empty_like`) parameter of
        `model_config.dtype` on `device_config.device`. Its values are
        presumably filled in later by `load_weights`.
        """
        meta_device = torch.device("meta")
        for name, param in module.named_parameters(recurse=False):
            if param.device == meta_device:
                new_param = nn.Parameter(
                    torch.empty_like(param.data,
                                     dtype=self.model_config.dtype,
                                     device=self.device_config.device))
                setattr(module, name, new_param)
        for child in module.children():
            self.init_parameters(child)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline rank's part of the model.

        On non-first ranks the input is taken from
        `intermediate_tensors["hidden_states"]`. A batch dimension of size 1
        is added before calling the Transformers model and removed from its
        output.

        Returns:
            `IntermediateTensors` for the next rank on non-last ranks,
            otherwise the hidden states.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        # Transformers expects a leading batch dimension
        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        if self.model_config.uses_mrope:
            # NOTE(review): presumably (num_sections, num_tokens) positions;
            # the batch dim is inserted after the section axis.
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Names are rewritten with the subclass-defined `hf_to_vllm_mapper`.
        Weights matched by `skip_prefixes` / `skip_substrs` are skipped.

        Returns:
            The set of names reported by `AutoWeightsLoader.load_weights`
            (presumably the loaded parameter names).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
701
702


@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersModel(TransformersBase):
    """Transformers-backend wrapper without an LM head of its own.

    Its weight-name mapping accepts BERT-like, base-model and pooling-adapter
    checkpoints, and it switches to encoder-only attention when the wrapped
    model contains non-causal attention modules.
    """

    # Rewrites from checkpoint parameter names to this module's names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # BERT-like checkpoints call the backbone `bert`
            "bert": "model",
            # Base-model checkpoints lack the `model.` prefix, so add it...
            "": "model.",
            # ...and collapse it again where the checkpoint already had it
            "model.model.": "model.",
            # Pooling adapters sit next to `model`, not inside it
            "model.pooler": "pooler",
            "model.score": "score",
            # The classifier adapter's layer is named `score` here
            "model.classifier": "score",
        },
        orig_to_new_suffix={
            # Legacy parameter suffixes used by norm layers
            ".gamma": ".weight",
            ".beta": ".bias",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # A pooling model ends up with two poolers: the Transformers one
        # inside `self.model` and vLLM's adapter next to it. Keep the adapter:
        # skip loading the inner pooler's weights and make it a no-op.
        if getattr(self.model, "pooler", None) is not None:
            self.skip_prefixes.append("pooler.")
            self.model.pooler = torch.nn.Identity()

        # Checkpoint entries vLLM never loads:
        # - `position_ids`: some encoder checkpoints store it as a buffer, but
        #   vLLM always passes position ids explicitly.
        # - `score.bias`: the final classifier bias is unused by vLLM.
        for substr in ("position_ids", "score.bias"):
            self.skip_substrs.append(substr)

    def create_attention_instances(
            self, attn_type: AttentionType = AttentionType.DECODER):
        """Create attention instances, detecting encoder-only models first.

        Any module with ``is_causal=False`` marks the whole model as an
        encoder (vLLM does not support encoder-decoder models here).

        Raises:
            ValueError: If the model is encoder-only and the installed
                ``transformers`` is older than 4.57.0.dev0.
        """

        # TODO(hmellor): Better way to detect encoder models
        def _is_encoder_layer(module: nn.Module) -> bool:
            return not getattr(module, "is_causal", True)

        if any(map(_is_encoder_layer, self.model.modules())):
            attn_type = AttentionType.ENCODER_ONLY

        if attn_type == AttentionType.ENCODER_ONLY:
            # Encoder support needs a minimum transformers version.
            import transformers
            from packaging.version import Version
            installed = Version(transformers.__version__)
            required = Version("4.57.0.dev0")
            if installed < required:
                raise ValueError(
                    "Encoder models with the Transformers backend require "
                    f"transformers>={required}, but got {installed}")

        return super().create_attention_instances(attn_type)

@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    """Transformers-backend causal LM with a vLLM LM head.

    The LM head and logits processor exist only on the last pipeline-parallel
    rank; other ranks hold a ``PPMissingLayer`` placeholder.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        text_config = self.text_config
        tied_embeddings = text_config.tie_word_embeddings

        # With tied embeddings the checkpoint's `lm_head` is redundant, so
        # ask `TransformersBase.load_weights` to skip it.
        if tied_embeddings:
            self.skip_prefixes.append("lm_head.")

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
            return

        self.unpadded_vocab_size = text_config.vocab_size
        self.lm_head = ParallelLMHead(
            text_config.vocab_size,
            text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if tied_embeddings:
            self.lm_head = self.lm_head.tie_weights(
                self.model.get_input_embeddings())

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            text_config.vocab_size,
            getattr(text_config, "logit_scale", 1.0),
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    """Flatten until a list of tensors can be concatenated, then concatenate.

    ``torch.concat`` along dim 0 needs every tensor to share the same trailing
    shape (``shape[1:]``). While they disagree, each tensor is split along its
    first dimension and all pieces are gathered into one flat list, so every
    round removes one leading dimension.

    Args:
        x: Tensors to merge.

    Returns:
        All tensors concatenated along dim 0 at the first depth where their
        trailing shapes agree.

    Raises:
        ValueError: If ``x`` is empty, or flattening runs out of tensors
            (every input has a zero-length leading dimension) before the
            shapes agree.
    """
    tensors = list(x)
    while True:
        # Without this guard an empty list is "flattened" into another empty
        # list forever.
        if not tensors:
            raise ValueError(
                "flatten_and_concat received no tensors to concatenate")
        # `torch.Size` is a tuple subclass, so shapes are hashable.
        if len({t.shape[1:] for t in tensors}) == 1:
            return torch.concat(tensors)
        # Peel off the leading dimension of every tensor.
        tensors = [row for t in tensors for row in t]


@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
    dummy_inputs=MultiModalDummyInputsBuilder)
@support_torch_compile(
    # set `positions` to last dim to support Qwen-mrope
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    },
    enable_if=can_enable_torch_compile)
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
    """Transformers-backend causal LM that also accepts image inputs.

    Image inputs (``pixel_values`` / ``image_patches`` / ``image_embeds``) are
    turned into embeddings and scattered into the positions of
    ``config.image_token_id`` tokens.
    """

    # Backwards compatibility for prev released models. State dicts back then
    # had different formats and cannot be loaded with `AutoModel` mapping as is
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each of submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # dtype that pixel values are cast to before the vision encoder
        self.dtype = vllm_config.model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Merge image embeddings into the inputs if needed, then run.

        ``kwargs`` carries the multimodal inputs consumed by
        ``get_multimodal_embeddings``.
        """
        # NOTE: In v1, inputs_embeds is always generated at model runner from
        # `get_multimodal_embeddings` and `get_input_embeddings`, this
        # condition is only for v0 compatibility.
        if inputs_embeds is None:
            multimodal_embeds = self.get_multimodal_embeddings(**kwargs)
            if multimodal_embeds is not None:
                inputs_embeds = self.get_input_embeddings(
                    input_ids, multimodal_embeds)
                input_ids = None

        return super().forward(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

    def get_multimodal_embeddings(self, **kwargs):
        """Compute image embeddings from multimodal keyword inputs.

        Returns:
            ``image_embeds`` unchanged when it is given; ``None`` when there
            is no pixel input; otherwise the output of the wrapped model's
            ``get_image_features``. When that output is a tensor, it is split
            along dim 0 by ``num_image_patches`` and each piece is flattened
            to 2D.

        Raises:
            KeyError: If pixel input is given without ``num_image_patches``.
            ValueError: If the pixel input is neither a tensor nor a list of
                tensors.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            # Fall back to the `image_patches` key.
            pixel_values = kwargs.pop("image_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if image_embeds is not None:
            return image_embeds

        if pixel_values is None:
            return None

        num_image_patches = kwargs.pop("num_image_patches")
        if isinstance(pixel_values, torch.Tensor):
            pixel_values = flatten_bn(pixel_values).to(self.dtype)
        elif is_list_of(pixel_values, torch.Tensor):
            pixel_values = flatten_and_concat(pixel_values).to(self.dtype)
        else:
            raise ValueError(
                f"Unsupported pixel_values type {type(pixel_values)}. "
                "Expected `torch.Tensor` or list of `torch.Tensor`.")

        if isinstance(num_image_patches, list):
            num_image_patches = torch.cat(num_image_patches)

        # Remaining kwargs are forwarded with their first two dims merged.
        vision_embeddings = self.model.get_image_features(
            pixel_values,
            **{
                k: v.flatten(0, 1)
                for k, v in kwargs.items()
            },
        )

        if isinstance(vision_embeddings, torch.Tensor):
            if vision_embeddings.ndim == 2:
                vision_embeddings = vision_embeddings.unsqueeze(0)

            # Embeddings have to be 2D tensors of length `num_images`
            # but transformers returns concat tensors if each patch
            # is of different size. We split it back to make vLLM happy
            vision_embeddings = torch.split(
                vision_embeddings,
                num_image_patches.flatten().tolist())
            vision_embeddings = [
                embed.flatten(start_dim=0, end_dim=-2)
                for embed in vision_embeddings
            ]

        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings=None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and scatter in multimodal embeddings.

        Args:
            input_ids: Token ids to embed.
            multimodal_embeddings: Optional sequence of embedding tensors;
                they are concatenated and written, in order, into the
                positions where ``input_ids == config.image_token_id``.

        Returns:
            The (possibly merged) input embeddings.
        """
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):
            mask = (input_ids == self.config.image_token_id)
            mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
            multimodal_embeddings = torch.cat(multimodal_embeddings)
            # `masked_scatter` requires the source to match the destination's
            # dtype; user-supplied `image_embeds` may not.
            multimodal_embeddings = multimodal_embeddings.to(
                dtype=inputs_embeds.dtype, device=inputs_embeds.device)

            inputs_embeds = inputs_embeds.masked_scatter(
                mask, multimodal_embeddings)
        return inputs_embeds