adapters.py 22.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from contextlib import contextmanager
6
from typing import TYPE_CHECKING, Any, TypeVar, cast
7
8
9
10

import torch
import torch.nn as nn

11
from vllm.config import VllmConfig
12
13
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
14
from vllm.model_executor.models.config import VerifyAndUpdateConfig
15
16
17
from vllm.transformers_utils.config import (
    try_get_dense_modules,
)
18
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
19

20
from .interfaces import supports_multimodal
21
from .interfaces_base import VllmModelForPooling, is_pooling_model
22

23
if TYPE_CHECKING:
24
    from vllm.config import ModelConfig, VllmConfig
25
    from vllm.model_executor.layers.pooler import Pooler
26

27
28
_T = TypeVar("_T", bound=type[nn.Module])

29
30
logger = init_logger(__name__)

31
32
33
34
35
36
# Class-name suffixes that `_get_pooling_model_name` strips from the original
# model class name before appending a pooling-specific suffix.
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]
37
38


39
def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
    """Build the Sentence-Transformers Dense projection stack, if one exists.

    Args:
        model_config: Supplies the model identifier, revision and the dtype
            (`head_dtype`) used for the projection layers.

    Returns:
        An `nn.Sequential` holding every Dense layer whose weights could be
        loaded (each optionally followed by its activation), or `None` when
        no dense-module description is available or construction raised.
    """
    module_specs = try_get_dense_modules(
        model_config.model, revision=model_config.revision
    )
    if module_specs is None:
        return None

    try:
        stack: list[nn.Module] = []
        for spec in module_specs:
            folder = spec["folder"]
            dense = nn.Linear(
                spec["in_features"],
                spec["out_features"],
                bias=spec.get("bias", True),
                dtype=model_config.head_dtype,
            )

            # A layer whose weights cannot be found is left out entirely,
            # together with its activation.
            if _load_dense_weights(dense, folder, model_config):
                stack.append(dense)
                activation_name = spec.get("activation_function")
                if activation_name:
                    stack.append(get_act_fn(activation_name))

        return nn.Sequential(*stack).to(dtype=model_config.head_dtype)
    except Exception:
        # Best effort: report the failure and signal "no projector".
        logger.exception("ST projector loading failed")
        return None


71
72
73
def _load_dense_weights(
    linear: nn.Linear, folder: str, model_config: "ModelConfig"
) -> bool:
    """Populate `linear` from a checkpoint file using the weight_loader pattern.

    `model.safetensors` is tried first, then `pytorch_model.bin`, each looked
    up under `folder` (or at the repository root when `folder` is empty).

    Args:
        linear: Layer whose `weight` (and `bias`, if present) is filled in.
        folder: Sub-folder holding the checkpoint files; may be empty.
        model_config: Supplies the model identifier and revision.

    Returns:
        `True` as soon as a file containing a recognised weight key has been
        loaded; `False` if no file yielded one. Errors for a given file are
        logged and the next file is tried.
    """
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    known_weight_keys = ("weight", "linear.weight", "dense.weight")

    for filename in ("model.safetensors", "pytorch_model.bin"):
        file_path = filename if not folder else f"{folder}/{filename}"

        try:
            payload = get_hf_file_bytes(
                file_path, model_config.model, model_config.revision
            )
            if not payload:
                continue

            if filename.endswith(".safetensors"):
                from safetensors.torch import load as load_safetensors

                state_dict = load_safetensors(payload)
            else:
                import io

                state_dict = torch.load(
                    io.BytesIO(payload), map_location="cpu", weights_only=True
                )

            # First recognised key wins; none present means try the next file.
            weight_key = next(
                (key for key in known_weight_keys if key in state_dict), None
            )
            if weight_key is None:
                continue

            load_weight = getattr(
                linear.weight, "weight_loader", default_weight_loader
            )
            load_weight(linear.weight, state_dict[weight_key])

            # The bias lives under the matching key, e.g. "dense.bias".
            bias_key = weight_key.replace("weight", "bias")
            if linear.bias is not None and bias_key in state_dict:
                load_bias = getattr(
                    linear.bias, "weight_loader", default_weight_loader
                )
                load_bias(linear.bias, state_dict[bias_key])

            return True
        except Exception:
            logger.exception("Failed to load %s", filename)
            continue

    return False


119
120
121
122
123
124
125
126
127
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Derive the class name of a pooling adapter.

    Each entry of `_GENERATE_SUFFIXES` is stripped in turn from the end of
    `orig_model_name` (when present), then `pooling_suffix` is appended.
    """
    stem = orig_model_name
    for suffix in _GENERATE_SUFFIXES:
        if stem.endswith(suffix):
            stem = stem[: len(stem) - len(suffix)]

    return f"{stem}{pooling_suffix}"


128
129
130
131
def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """Create a `ModelForPooling` subclass of `orig_cls`.

    The returned class inherits from `orig_cls` and `VllmModelForPooling`.
    It constructs the original model under `no_init_weights` (targeting
    `LogitsProcessor` and `ParallelLMHead`), ensures a `pooler` attribute is
    set, and provides a `load_weights` that accepts checkpoints whose
    parameter names are relative either to the model itself or to its
    `model` attribute.

    Subclasses must override `_init_pooler`.
    """
    # Lazy import
    from itertools import chain

    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

    from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights

    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            # Build the wrapped model; the targeted output layers are handed
            # to `StageMissingLayer("output", ...)` by `no_init_weights`.
            with no_init_weights(
                self,
                lambda mod: StageMissingLayer("output", mod),
                targets=(LogitsProcessor, ParallelLMHead),
            ):
                super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Used by SEQ_CLS_LOAD_METHODS
            self.vllm_config = vllm_config

            # If the model already defines a pooler instance, don't overwrite it
            pooler = getattr(self, "pooler", None)
            if not pooler and supports_multimodal(self):
                # Try to get the pooler from the LM backbone
                language_model = self.get_language_model()
                if hasattr(language_model, "pooler"):
                    pooler = language_model.pooler

            if not pooler:
                pooler = self._init_pooler(vllm_config, prefix=prefix)

            self.pooler = pooler

        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            """Create the pooler for this model; must be overridden."""
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            """Load `weights`, remapping names when they are relative to a
            sub-module.

            The iterator is consumed only until the first weight whose name
            (under one of the candidate prefixes) matches a parameter of this
            model; that prefix is then prepended to every weight name before
            delegating to the parent's `load_weights` (or to
            `AutoWeightsLoader` when the parent defines none).
            """
            params_dict = dict(self.named_parameters())

            # We support loading from both `*ForCausalLM` and `*Model`
            candidate_prefixes = ["", "model."]
            target_prefix = ""

            # Weights pulled from the iterator while probing; replayed below.
            seen_weights: list[tuple[str, torch.Tensor]] = []
            for name, loaded_weight in weights:
                seen_weights.append((name, loaded_weight))

                matched_prefix = next(
                    (
                        candidate
                        for candidate in candidate_prefixes
                        if candidate + name in params_dict
                    ),
                    None,
                )
                # No match: the weight might not exist on the model
                # (to be handled by AutoWeightsLoader); keep probing.
                if matched_prefix is not None:
                    target_prefix = matched_prefix
                    break

            if target_prefix:
                # Walk the dotted prefix one attribute at a time, starting
                # from this model and descending from the module reached so
                # far (not from `self` on every step).
                target_model = self
                for attr in target_prefix.split("."):
                    if attr:
                        target_model = getattr(target_model, attr)

                logger.info(
                    "Mapping weights to %s as they are "
                    "relative to this model instead of %s.",
                    target_model._get_name(),
                    self._get_name(),
                )

            # Replay the probed weights, then continue with the rest of the
            # iterator lazily instead of materialising every tensor up front.
            mapped_weights = (
                (target_prefix + name, weight)
                for name, weight in chain(seen_weights, weights)
            )

            def default_load_weights(weights):
                loader = AutoWeightsLoader(self)
                return loader.load_weights(weights)

            load_weights = getattr(super(), "load_weights", default_load_weights)
            return load_weights(mapped_weights)

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Wrap an existing vLLM model class so that it produces embeddings.

    By default, the embeddings of the whole prompt are extracted from the
    normalized hidden state corresponding to the last token.

    Classes that are already pooling models are returned unchanged.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    if is_pooling_model(cls):
        # Avoid modifying existing embedding models
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForEmbedding(pooling_base):
        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            """Return the embedding pooler built from the pooler config."""
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None
            return DispatchPooler.for_embedding(pooler_config)

    # Expose a descriptive class name, e.g. `FooForCausalLM` -> `FooForEmbedding`.
    adapted_name = _get_pooling_model_name(cls.__name__, "ForEmbedding")
    ModelForEmbedding.__name__ = adapted_name

    return ModelForEmbedding  # type: ignore
258
259


260
def as_seq_cls_model(cls: _T) -> _T:
    """
    Wrap an existing vLLM model class to support classify and score tasks.

    By default, the class probabilities are extracted from the softmaxed
    hidden state corresponding to the last token.

    Classes that are already pooling models are returned unchanged.

    Note:
        We assume that the classification head is a single linear layer
        stored as the attribute `score` of the top-level model;
        please implement your own model if this is not the case.
    """
    if is_pooling_model(cls):
        # Avoid modifying existing classification models
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.pooler import DispatchPooler
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding

    from .utils import maybe_prefix

    class ModelForSequenceClassification(
        _create_pooling_model_cls(cls), SupportsCrossEncoding
    ):
        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            """Create the `score` head and the sequence-classification pooler."""
            model_config = vllm_config.model_config
            hf_config = model_config.hf_config
            text_config = hf_config.get_text_config()

            # Check if score weights are derived online from LM head
            # (same condition as load_weights branch)
            text_level_tokens = getattr(text_config, "classifier_from_token", None)
            tokens = getattr(hf_config, "classifier_from_token", text_level_tokens)
            text_level_method = getattr(text_config, "method", None)
            method = getattr(hf_config, "method", text_level_method)
            converts_online = tokens is not None or method is not None

            # Online conversion: no score weights in checkpoint, don't
            # quantize (small output_dim breaks FP8/Marlin tile alignment).
            # Checkpoint-based: respect the model's quant_config.
            if converts_online:
                quant_config = None
            else:
                quant_config = vllm_config.quant_config

            self.score = ReplicatedLinear(
                model_config.get_hidden_size(),
                text_config.num_labels,
                bias=False,
                params_dtype=model_config.head_dtype,
                quant_config=quant_config,
                return_bias=False,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            """Load weights, either directly or by online conversion.

            A `score.bias` entry in the stream is installed on `self.score`
            and removed from the stream; all other entries pass through.
            """
            hf_config = self.config
            text_config = hf_config.get_text_config()
            text_level_tokens = getattr(text_config, "classifier_from_token", None)
            tokens = getattr(hf_config, "classifier_from_token", text_level_tokens)
            method = getattr(hf_config, "method", getattr(text_config, "method", None))

            def auto_set_score_bias(weights):
                # `score` is created without a bias; attach one on the fly
                # when the checkpoint provides it.
                for name, tensor in weights:
                    if name != "score.bias":
                        yield name, tensor
                        continue

                    device = self.score.weight.device
                    dtype = self.score.weight.dtype
                    bias = tensor.to(device).to(dtype)
                    self.score.bias = torch.nn.Parameter(bias)
                    self.score.skip_bias_add = False

            weights = auto_set_score_bias(weights)

            if tokens is not None or method is not None:
                # Online convert ForCausalLM into
                # ForSequenceClassification model.
                return seq_cls_model_loader(self, weights)

            return super().load_weights(weights)

    ModelForSequenceClassification.__name__ = _get_pooling_model_name(
        cls.__name__, "ForSequenceClassification"
    )

    return ModelForSequenceClassification  # type: ignore
366
367


368
369
370
class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for models converted online into sequence classifiers."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate the online-conversion settings and derive `num_labels`.

        Does nothing when no `method` is configured. Otherwise
        `classifier_from_token` must be set and `method` must be a key of
        `SEQ_CLS_LOAD_METHODS`; `num_labels` is then written to both the HF
        config and its text config, and `use_sep_token` is defaulted to
        `False` on the text config if absent.
        """
        hf_config = vllm_config.model_config.hf_config
        text_config = hf_config.get_text_config()

        method = getattr(hf_config, "method", getattr(text_config, "method", None))
        if method is None:
            return

        text_level_tokens = getattr(text_config, "classifier_from_token", None)
        tokens = getattr(hf_config, "classifier_from_token", text_level_tokens)

        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            # Two tokens collapse into a single label.
            assert len(tokens) == 2
            num_labels = 1
        else:
            num_labels = len(tokens)

        hf_config.num_labels = num_labels
        text_config.num_labels = num_labels

        # `llm as reranker` defaults to not using separating token.
        text_config.use_sep_token = getattr(text_config, "use_sep_token", False)
397

398

399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def _get_language_model_for_seq_cls(model) -> nn.Module:
    """
    Locate the language-model component used for sequence-classification
    conversion.

    Lookup order:
      1. `model.get_language_model()` when the model supports multimodal
         inputs and the result is not the model itself (errors are ignored).
      2. The first of the attributes `language_model`, `lm`, `text_model`
         that is an `nn.Module` other than `model` and has a `model`
         attribute.
      3. The first direct child whose class name contains `ForCausalLM` or
         `LMHead` and which has a `model` attribute.

    If nothing matches, `model` itself is returned.
    """
    if supports_multimodal(model):
        try:
            inner = model.get_language_model()
        except Exception:
            # Best effort; fall through to the attribute-based lookups.
            inner = model
        if inner is not model:
            return inner

    for attr_name in ("language_model", "lm", "text_model"):
        candidate = getattr(model, attr_name, None)
        if not isinstance(candidate, nn.Module):
            continue
        if candidate is not model and hasattr(candidate, "model"):
            return candidate

    for _, child in model.named_children():
        cls_name = type(child).__name__
        looks_like_lm = "ForCausalLM" in cls_name or "LMHead" in cls_name
        if looks_like_lm and hasattr(child, "model"):
            return child

    return model


@contextmanager
def _disable_seq_cls_loading_on_inner_model(language_model, is_vlm: bool):
    """
    Context manager to temporarily disable sequence classification loading
    on inner VLM models to prevent recursive seq_cls_model_loader calls.
    """
    if not is_vlm:
        yield
        return

    inner_hf_config = getattr(language_model, "config", None)
    if inner_hf_config is None:
        yield
        return

    inner_text_config = inner_hf_config.get_text_config()
    original_method = getattr(inner_text_config, "method", None)
    original_tokens = getattr(inner_text_config, "classifier_from_token", None)
    original_hf_tokens = getattr(inner_hf_config, "classifier_from_token", None)

    try:
        if original_method is not None:
            inner_text_config.method = None
        if original_tokens is not None:
            inner_text_config.classifier_from_token = None
        if original_hf_tokens is not None:
            inner_hf_config.classifier_from_token = None
        yield
    finally:
        if original_method is not None:
            inner_text_config.method = original_method
        if original_tokens is not None:
            inner_text_config.classifier_from_token = original_tokens
        if original_hf_tokens is not None:
            inner_hf_config.classifier_from_token = original_hf_tokens


469
def load_weights_using_from_2_way_softmax(
    model, weights: Iterable[tuple[str, torch.Tensor]]
):
    """Load a causal-LM checkpoint and derive a single-row score head from it.

    The score weight is the float32 difference between the LM-head rows of
    the second ("true") and first ("false") token listed in
    `classifier_from_token`.

    Args:
        model: The pooling model being loaded; must expose `vllm_config`,
            `config` and a `score` layer (or an inner language model with one).
        weights: `(name, tensor)` pairs from the checkpoint.

    Returns:
        The set of loaded weight names reported by `ModelForPooling.load_weights`,
        with the score weight added and the LM-head weight removed.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    model_config = model.vllm_config.model_config
    hf_config = model.config
    text_config = hf_config.get_text_config()

    text_level_tokens = getattr(text_config, "classifier_from_token", [])
    tokens = cast(
        list[int], getattr(hf_config, "classifier_from_token", text_level_tokens)
    )
    assert len(tokens) == 2

    language_model = _get_language_model_for_seq_cls(model)
    is_vlm = language_model is not model
    using_vlm_head = is_vlm and hasattr(language_model, "score")

    # Attach a temporary LM head before loading; it is deleted again below
    # once the score weight has been derived from it.
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size,
        text_config.hidden_size,
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        if hasattr(text_backbone, "embed_tokens"):
            embed_tokens = text_backbone.embed_tokens
        else:
            embed_tokens = text_backbone.get_input_embeddings()
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
        # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
        # function, so we need use this hacky method to obtain it.
        pooling_model_cls = next(
            klass for klass in type(model).__mro__ if klass.__name__ == "ModelForPooling"
        )
        loaded_weights = pooling_model_cls.load_weights(model, weights)

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    # List indexing keeps a leading dimension of size one on each row.
    head_weight = language_model.lm_head.weight
    true_row = head_weight.data[[true_id]].to(torch.float32)
    false_row = head_weight.data[[false_id]].to(torch.float32)
    score_weight = true_row - false_row

    if using_vlm_head:
        score_layer = language_model.score
        score_weight_name = "language_model.score.weight"
    else:
        score_layer = model.score
        score_weight_name = "score.weight"

    param = score_layer.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    del language_model.lm_head

    loaded_weights.add(score_weight_name)

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)

    return loaded_weights


551
552
553
def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load a causal-LM checkpoint and copy selected LM-head rows into the
    score head.

    The score weight consists of the LM-head rows of every token listed in
    `classifier_from_token`, in the listed order, without further processing.

    `classifier_from_token` is read from the top-level HF config first and
    from its text config as a fallback, matching the lookup used by the other
    sequence-classification helpers in this module.

    Args:
        model: The pooling model being loaded; must expose `vllm_config`,
            `config` and a `score` layer (or an inner language model with one).
        weights: `(name, tensor)` pairs from the checkpoint.

    Returns:
        The set of loaded weight names reported by `ModelForPooling.load_weights`,
        with the score weight added and the LM-head weight removed.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    model_config = model.vllm_config.model_config
    hf_config = model.config
    text_config = hf_config.get_text_config()

    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", []),
    )
    tokens = cast(list[int], tokens)
    assert len(tokens) > 0

    language_model = _get_language_model_for_seq_cls(model)
    is_vlm = language_model is not model
    using_vlm_head = is_vlm and hasattr(language_model, "score")

    # Attach a temporary LM head before loading; it is deleted again below
    # once the score weight has been copied out of it.
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size,
        text_config.hidden_size,
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        embed_tokens = (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
        pooling_model_cls = next(
            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
        )
        # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
        loaded_weights = pooling_model_cls.load_weights(model, weights)

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    score_weight = language_model.lm_head.weight.data[token_ids]

    score_layer = language_model.score if using_vlm_head else model.score
    param = score_layer.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    del language_model.lm_head

    score_weight_name = (
        "language_model.score.weight" if using_vlm_head else "score.weight"
    )
    loaded_weights.add(score_weight_name)

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)

    return loaded_weights


620
621
# Maps the `method` config value to the loader that converts a causal-LM
# checkpoint into a sequence-classification model. Each loader is called as
# `loader(model, weights)` by `seq_cls_model_loader`.
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
    "no_post_processing": load_weights_no_post_processing,
}


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Online convert ForCausalLM into ForSequenceClassification model.

    The conversion routine is chosen by the `method` attribute of the HF
    config (falling back to its text config) and must be a key of
    `SEQ_CLS_LOAD_METHODS`.

    Known users of each method:
      - from_2_way_softmax:
        - Qwen3ForCausalLM
          - Qwen3-Reranker
        - Qwen2ForCausalLM
          - mxbai-rerank-v2
      - no_post_processing:
        - GemmaForCausalLM
          - bge-reranker-v2-gemma

    Returns:
        Whatever the selected loader returns.
    """
    hf_config = model.vllm_config.model_config.hf_config
    text_level_method = getattr(hf_config.get_text_config(), "method", None)
    method = getattr(hf_config, "method", text_level_method)

    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    load_fn = SEQ_CLS_LOAD_METHODS[method]
    return load_fn(model, weights)