adapters.py 22.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Iterable
6
from contextlib import contextmanager
7
from typing import TYPE_CHECKING, Any, TypeVar, cast
8
9
10
11

import torch
import torch.nn as nn

12
from vllm.config import VllmConfig
13
14
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
15
from vllm.model_executor.models.config import VerifyAndUpdateConfig
16
17
18
from vllm.transformers_utils.config import (
    try_get_dense_modules,
)
19
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
20

21
from .interfaces import supports_multimodal
22
from .interfaces_base import VllmModelForPooling, is_pooling_model
23

24
if TYPE_CHECKING:
25
    from vllm.config import ModelConfig, VllmConfig
26
    from vllm.model_executor.layers.pooler import Pooler
27

28
29
logger = init_logger(__name__)

# Type variable for the class adapters below: they take a model class and
# return a (sub)class of it.
_T = TypeVar("_T", bound=type[nn.Module])

# Architecture-name suffixes of generative models; `_get_pooling_model_name`
# strips these before appending a pooling suffix to build the adapter name.
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]
38
39


40
def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
    """Load Sentence-Transformers Dense projection layers.

    Each entry returned by ``try_get_dense_modules`` describes one ``Dense``
    layer (``folder``, ``in_features``, ``out_features``, optional ``bias``
    and ``activation_function``). Each layer is built as ``nn.Linear`` in
    ``model_config.head_dtype``, filled by ``_load_dense_weights``, and
    followed by its activation if one is configured.

    Returns:
        The stacked ``nn.Sequential`` head, or None when the repository
        declares no Dense modules or anything goes wrong while building it
        (the error is logged).
    """
    dense_modules = try_get_dense_modules(
        model_config.model, revision=model_config.revision
    )
    if dense_modules is None:
        return None

    try:
        head_dtype = model_config.head_dtype
        stack: list[nn.Module] = []
        for spec in dense_modules:
            folder = spec["folder"]
            dense = nn.Linear(
                spec["in_features"],
                spec["out_features"],
                bias=spec.get("bias", True),
                dtype=head_dtype,
            )

            # A layer whose weights could not be found is skipped entirely,
            # together with its activation function.
            if not _load_dense_weights(dense, folder, model_config):
                continue
            stack.append(dense)

            act_name = spec.get("activation_function")
            if act_name:
                stack.append(get_act_fn(act_name))

        return nn.Sequential(*stack).to(dtype=head_dtype)
    except Exception:
        logger.exception("ST projector loading failed")
        return None


72
73
74
def _load_dense_weights(
    linear: nn.Linear, folder: str, model_config: "ModelConfig"
) -> bool:
    """Load weights using vLLM's weight_loader pattern.

    Looks for ``model.safetensors`` and then ``pytorch_model.bin`` under
    ``folder`` (or the repository root when ``folder`` is empty) of
    ``model_config.model`` at ``model_config.revision``. From the first file
    that contains one of the keys ``weight``, ``linear.weight`` or
    ``dense.weight``, that tensor is loaded into ``linear.weight``. The
    matching ``*bias`` tensor, if present, is loaded into ``linear.bias``
    when the layer has a bias.

    Returns:
        True once a weight tensor has been loaded; False if no candidate
        file provided one. Per-file errors are logged (with the full path)
        and the next candidate is tried.
    """
    import io

    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    weight_keys = ("weight", "linear.weight", "dense.weight")

    for filename in ("model.safetensors", "pytorch_model.bin"):
        file_path = f"{folder}/{filename}" if folder else filename

        try:
            file_bytes = get_hf_file_bytes(
                file_path, model_config.model, model_config.revision
            )
            if not file_bytes:
                continue

            if filename.endswith(".safetensors"):
                # Imported lazily so a missing/broken safetensors install is
                # caught below and the .bin fallback can still be tried.
                from safetensors.torch import load as load_safetensors

                state_dict = load_safetensors(file_bytes)
            else:
                # weights_only=True restricts unpickling to tensors and plain
                # containers, so a downloaded .bin cannot run pickled code.
                state_dict = torch.load(
                    io.BytesIO(file_bytes), map_location="cpu", weights_only=True
                )

            weight_key = next((k for k in weight_keys if k in state_dict), None)
            if weight_key is None:
                # The caller drops a Dense layer whose weights are not found,
                # so record why instead of skipping the file silently.
                logger.warning(
                    "None of the weight keys %s found in %s",
                    weight_keys,
                    file_path,
                )
                continue

            weight_loader = getattr(
                linear.weight, "weight_loader", default_weight_loader
            )
            weight_loader(linear.weight, state_dict[weight_key])

            bias_key = weight_key.replace("weight", "bias")
            if linear.bias is not None and bias_key in state_dict:
                bias_loader = getattr(
                    linear.bias, "weight_loader", default_weight_loader
                )
                bias_loader(linear.bias, state_dict[bias_key])

            return True
        except Exception:
            # Report the full path: several Dense folders share file names.
            logger.exception("Failed to load %s", file_path)

    return False


120
121
122
123
124
125
126
127
128
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Name the pooling variant of a generative architecture.

    Every entry of ``_GENERATE_SUFFIXES`` is removed from the end of
    ``orig_model_name`` in list order, so several can be stripped in turn.
    Then ``pooling_suffix`` is appended, e.g.
    ``("LlamaForCausalLM", "ForEmbedding") -> "LlamaForEmbedding"``.
    """
    base_name = orig_model_name
    for suffix in _GENERATE_SUFFIXES:
        base_name = base_name.removesuffix(suffix)
    return base_name + pooling_suffix


129
130
131
132
def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """Return a ``ModelForPooling`` subclass of ``orig_cls``.

    The generated class:

    * constructs ``orig_cls`` inside ``no_init_weights`` targeting
      ``LogitsProcessor`` and ``ParallelLMHead``. Presumably these are
      substituted by ``StageMissingLayer("output", ...)`` placeholders;
      see ``.utils`` to confirm.
    * attaches a pooler. It reuses one already defined on the model, or on
      the language backbone of a multimodal model, and otherwise calls
      ``_init_pooler``, which concrete adapters must override.
    * accepts checkpoints whose weight names are relative either to this
      model or to its ``model.`` submodule.
    """
    # Lazy import
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

    from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights

    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            with no_init_weights(
                self,
                lambda mod: StageMissingLayer("output", mod),
                targets=(LogitsProcessor, ParallelLMHead),
            ):
                super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Used by SEQ_CLS_LOAD_METHODS
            self.vllm_config = vllm_config

            # If the model already defines a pooler instance, don't overwrite
            # it. Compare against None: a Module may define __len__ and thus
            # be falsy even though it exists.
            pooler = getattr(self, "pooler", None)
            if pooler is None and supports_multimodal(self):
                # Try to get the pooler from the LM backbone
                language_model = self.get_language_model()
                pooler = getattr(language_model, "pooler", None)

            if pooler is None:
                pooler = self._init_pooler(vllm_config, prefix=prefix)

            self.pooler = pooler

        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            """Build this model's pooler; concrete adapters override this."""
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            """Load weights named relative to ``self`` or to ``self.model``.

            The first weights are peeked until one resolves to a parameter
            under a candidate prefix. That prefix is then applied to every
            weight, and the peeked ones are replayed before the rest.
            """
            params_dict = dict(self.named_parameters())

            # Take a single iterator: `weights` may be a re-iterable container
            # (e.g. a list), and re-iterating it below would replay the
            # already-peeked weights a second time.
            weights_iter = iter(weights)

            # We support loading from both `*ForCausalLM` and `*Model`
            candidate_prefixes = ["", "model."]
            target_prefix = ""

            seen_weights = list[tuple[str, torch.Tensor]]()
            for name, loaded_weight in weights_iter:
                # Clone because the iterator may reuse the tensor buffer
                seen_weights.append((name, loaded_weight.clone()))

                matched = next(
                    (p for p in candidate_prefixes if p + name in params_dict),
                    None,
                )
                if matched is not None:
                    target_prefix = matched
                    break
                # Otherwise the weight might not exist on the model
                # (to be handled by AutoWeightsLoader)

            if target_prefix:
                # Walk the dotted prefix from `self` down to the submodule the
                # weights are relative to.
                target_model = self
                for attr in target_prefix.split("."):
                    if attr:
                        target_model = getattr(target_model, attr)

                logger.info(
                    "Mapping weights to %s as they are "
                    "relative to this model instead of %s.",
                    target_model._get_name(),
                    self._get_name(),
                )

            # Lazy chain so buffer-reusing weight iterators (e.g.
            # runai_streamer) are consumed one tensor at a time.
            mapped_weights = (
                (target_prefix + name, weight)
                for name, weight in itertools.chain(seen_weights, weights_iter)
            )

            def default_load_weights(weights):
                loader = AutoWeightsLoader(self)
                return loader.load_weights(weights)

            load_weights = getattr(super(), "load_weights", default_load_weights)
            return load_weights(mapped_weights)

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Wrap a vLLM model class so it can serve embedding requests.

    Unless the pooler config says otherwise, the prompt embedding is taken
    from the normalized hidden state of the last token.

    Note:
        The adapter adds no layers of its own; models that need extra
        layers should be implemented explicitly instead.
    """
    if is_pooling_model(cls):
        # Already a pooling model: wrapping it again would replace its pooler.
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForEmbedding(pooling_base):
        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None
            return DispatchPooler.for_embedding(pooler_config)

    ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
    return ModelForEmbedding  # type: ignore
262
263


264
def as_seq_cls_model(cls: _T) -> _T:
    """
    Wrap a vLLM model class so it can serve classify and score requests.

    By default, class probabilities come from the softmax over the score
    head applied to the last token's hidden state.

    Note:
        The classification head is assumed to be a single linear layer
        stored as the top-level attribute `score`; models with a different
        head should be implemented explicitly instead.
    """
    if is_pooling_model(cls):
        # Already a classification/pooling model: leave it untouched.
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.pooler import DispatchPooler
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding

    from .utils import maybe_prefix

    def _conversion_settings(hf_config, text_config):
        """Return ``(classifier_from_token, method)``.

        Each value is read from ``hf_config`` first and ``text_config``
        second; both are None for a checkpoint with a real ``score`` head.
        """
        tokens = getattr(
            hf_config,
            "classifier_from_token",
            getattr(text_config, "classifier_from_token", None),
        )
        method = getattr(hf_config, "method", getattr(text_config, "method", None))
        return tokens, method

    class ModelForSequenceClassification(
        _create_pooling_model_cls(cls), SupportsCrossEncoding
    ):
        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            model_config = vllm_config.model_config
            hf_config = model_config.hf_config
            text_config = hf_config.get_text_config()

            # Same condition as the online-conversion branch of load_weights.
            tokens, method = _conversion_settings(hf_config, text_config)
            converted_online = tokens is not None or method is not None

            # Online conversion: no score weights in checkpoint, don't
            # quantize (small output_dim breaks FP8/Marlin tile alignment).
            # Checkpoint-based: respect the model's quant_config.
            quant_config = None if converted_online else vllm_config.quant_config

            self.score = ReplicatedLinear(
                model_config.get_hidden_size(),
                text_config.num_labels,
                bias=False,
                params_dtype=model_config.head_dtype,
                quant_config=quant_config,
                return_bias=False,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = model_config.pooler_config
            assert pooler_config is not None

            return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            hf_config = self.config
            tokens, method = _conversion_settings(
                hf_config, hf_config.get_text_config()
            )

            def _absorb_score_bias(stream):
                # `score` is built with bias=False, so a checkpoint-provided
                # "score.bias" is installed here instead of being forwarded.
                for name, tensor in stream:
                    if name != "score.bias":
                        yield name, tensor
                        continue
                    device = self.score.weight.device
                    dtype = self.score.weight.dtype
                    self.score.bias = torch.nn.Parameter(tensor.to(device).to(dtype))
                    self.score.skip_bias_add = False

            weights = _absorb_score_bias(weights)

            if tokens is not None or method is not None:
                # Online convert ForCausalLM into
                # ForSequenceClassification model.
                return seq_cls_model_loader(self, weights)
            return super().load_weights(weights)

    ModelForSequenceClassification.__name__ = _get_pooling_model_name(
        cls.__name__, "ForSequenceClassification"
    )
    return ModelForSequenceClassification  # type: ignore
370
371


372
373
374
class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for causal LMs served as classifiers/rerankers.

    Only acts when the HF config carries a conversion ``method`` (a key of
    ``SEQ_CLS_LOAD_METHODS``). In that case ``num_labels`` is derived from
    ``classifier_from_token`` and ``use_sep_token`` defaults to False.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        hf_config = vllm_config.model_config.hf_config
        text_config = hf_config.get_text_config()
        method = getattr(hf_config, "method", getattr(text_config, "method", None))
        tokens = getattr(
            hf_config,
            "classifier_from_token",
            getattr(text_config, "classifier_from_token", None),
        )

        # Nothing to adjust unless an online conversion method is configured.
        if method is None:
            return

        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            # The two token rows are folded into one logit (true - false).
            assert len(tokens) == 2
            num_labels = 1
        else:
            num_labels = len(tokens)
        hf_config.num_labels = num_labels
        text_config.num_labels = num_labels

        # `llm as reranker` defaults to not using separating token.
        text_config.use_sep_token = getattr(text_config, "use_sep_token", False)
401

402

403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
def _get_language_model_for_seq_cls(model) -> nn.Module:
    """
    Get the language model component for sequence classification conversion.

    Lookup order:

    1. For multimodal models, ``model.get_language_model()`` if it returns a
       module other than ``model``. Errors are logged at debug level and
       the lookup continues.
    2. The first of ``language_model``, ``lm``, ``text_model`` that is an
       ``nn.Module`` other than ``model`` and has a ``model`` attribute.
    3. The first direct child whose class name contains ``ForCausalLM`` or
       ``LMHead`` and which has a ``model`` attribute.

    Falls back to ``model`` itself, which is the standard-LLM case.
    """
    if supports_multimodal(model):
        try:
            lm = model.get_language_model()
        except Exception:
            # Best effort: fall through to the attribute-based lookup, but
            # leave a trace instead of swallowing the error silently.
            logger.debug(
                "get_language_model() failed; falling back to attribute lookup",
                exc_info=True,
            )
        else:
            if lm is not model:
                return lm

    for attr_name in ("language_model", "lm", "text_model"):
        candidate = getattr(model, attr_name, None)
        if (
            isinstance(candidate, nn.Module)
            and candidate is not model
            and hasattr(candidate, "model")
        ):
            return candidate

    for child in model.children():
        child_name = type(child).__name__
        if ("ForCausalLM" in child_name or "LMHead" in child_name) and hasattr(
            child, "model"
        ):
            return child

    return model


@contextmanager
def _disable_seq_cls_loading_on_inner_model(language_model, is_vlm: bool):
    """
    Temporarily clear seq-cls conversion settings on an inner model's config.

    This prevents recursive ``seq_cls_model_loader`` calls from the inner
    language model of a VLM while the outer model loads its weights. The
    conversion settings are ``method`` and ``classifier_from_token``. They
    are cleared on both ``language_model.config`` and its text sub-config,
    because the loaders read the top-level config first and fall back to
    the text config. Original values are restored on exit, even if the
    body raises.

    Args:
        language_model: Module whose ``config`` should be neutralised.
        is_vlm: When False, this is a no-op.
    """
    if not is_vlm:
        yield
        return

    inner_hf_config = getattr(language_model, "config", None)
    if inner_hf_config is None:
        yield
        return

    inner_text_config = inner_hf_config.get_text_config()

    # Snapshot every non-None value before mutating anything: the two configs
    # may be the same object, and a later read must not observe a None we set.
    saved = []
    for cfg in (inner_text_config, inner_hf_config):
        for attr in ("method", "classifier_from_token"):
            value = getattr(cfg, attr, None)
            if value is not None:
                saved.append((cfg, attr, value))

    try:
        for cfg, attr, _ in saved:
            setattr(cfg, attr, None)
        yield
    finally:
        for cfg, attr, value in reversed(saved):
            setattr(cfg, attr, value)


473
def load_weights_using_from_2_way_softmax(
    model, weights: Iterable[tuple[str, torch.Tensor]]
):
    """Load a causal-LM checkpoint into a single-logit classification head.

    ``classifier_from_token`` must hold exactly two token strings
    ``[false_token, true_token]``. The ``score`` weight is set to
    ``lm_head[true_id] - lm_head[false_id]``. The sigmoid of that logit
    equals the 2-way softmax probability of ``true_token``.

    Returns:
        The set of loaded parameter names. The score weight name is added,
        and the temporary ``lm_head.weight`` (mapped through
        ``hf_to_vllm_mapper`` if the model has one) is removed.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    model_config = model.vllm_config.model_config
    hf_config = model.config
    text_config = hf_config.get_text_config()

    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", []),
    )
    # These are token *strings*; they are mapped to ids via the tokenizer below.
    tokens = cast(list[str], tokens)
    assert len(tokens) == 2

    language_model = _get_language_model_for_seq_cls(model)
    is_vlm = language_model is not model
    using_vlm_head = is_vlm and hasattr(language_model, "score")

    # Temporary LM head: receives the checkpoint's lm_head weights (or is tied
    # to the input embeddings) and is deleted once `score` has been derived.
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size,
        text_config.hidden_size,
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        embed_tokens = (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
        # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
        # function, so we need use this hacky method to obtain it.
        pooling_model_cls = next(
            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
        )
        loaded_weights = pooling_model_cls.load_weights(model, weights)

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    # NOTE(review): rows are indexed by global token id; confirm this holds
    # when ParallelLMHead shards the vocabulary across tensor-parallel ranks.
    lm_head_weight = language_model.lm_head.weight
    # The difference is taken in float32, presumably to limit cancellation
    # error in half-precision weights.
    score_weight = lm_head_weight.data[[true_id]].to(
        torch.float32
    ) - lm_head_weight.data[[false_id]].to(torch.float32)

    score_layer = language_model.score if using_vlm_head else model.score
    param = score_layer.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    del language_model.lm_head

    score_weight_name = (
        "language_model.score.weight" if using_vlm_head else "score.weight"
    )
    loaded_weights.add(score_weight_name)

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)

    return loaded_weights


555
556
557
def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load a causal-LM checkpoint into a multi-logit classification head.

    Row ``i`` of ``score`` is the LM-head row of the i-th token string in
    ``classifier_from_token``; the logits are used as-is.

    ``classifier_from_token`` is looked up on the top-level config first and
    the text sub-config second. ``SequenceClassificationConfig`` uses the
    same precedence to size ``num_labels``, so the loaded rows always match
    the head's shape.

    Returns:
        The set of loaded parameter names. The score weight name is added,
        and the temporary ``lm_head.weight`` (mapped through
        ``hf_to_vllm_mapper`` if the model has one) is removed.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    model_config = model.vllm_config.model_config
    hf_config = model.config
    text_config = hf_config.get_text_config()

    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", []),
    )
    # These are token *strings*; they are mapped to ids via the tokenizer below.
    tokens = cast(list[str], tokens)
    assert len(tokens) > 0

    language_model = _get_language_model_for_seq_cls(model)
    is_vlm = language_model is not model
    using_vlm_head = is_vlm and hasattr(language_model, "score")

    # Temporary LM head: receives the checkpoint's lm_head weights (or is tied
    # to the input embeddings) and is deleted once `score` has been derived.
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size,
        text_config.hidden_size,
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        embed_tokens = (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
        pooling_model_cls = next(
            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
        )
        # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
        loaded_weights = pooling_model_cls.load_weights(model, weights)

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    # NOTE(review): rows are indexed by global token id; confirm this holds
    # when ParallelLMHead shards the vocabulary across tensor-parallel ranks.
    score_weight = language_model.lm_head.weight.data[token_ids]

    score_layer = language_model.score if using_vlm_head else model.score
    param = score_layer.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    del language_model.lm_head

    score_weight_name = (
        "language_model.score.weight" if using_vlm_head else "score.weight"
    )
    loaded_weights.add(score_weight_name)

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)

    return loaded_weights


624
625
# Online ForCausalLM -> ForSequenceClassification loaders, keyed by the value
# of the HF config's `method` attribute.
SEQ_CLS_LOAD_METHODS = dict(
    from_2_way_softmax=load_weights_using_from_2_way_softmax,
    no_post_processing=load_weights_no_post_processing,
)


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Online-convert a ForCausalLM checkpoint into a seq-cls model.

    The HF config's ``method`` is read from the top-level config, falling
    back to the text sub-config. It selects the loader in
    ``SEQ_CLS_LOAD_METHODS``:

    - ``from_2_way_softmax``: Qwen3ForCausalLM (Qwen3-Reranker) and
      Qwen2ForCausalLM (mxbai-rerank-v2)
    - ``no_post_processing``: GemmaForCausalLM (bge-reranker-v2-gemma)
    """
    hf_config = model.vllm_config.model_config.hf_config
    text_config = hf_config.get_text_config()
    method = getattr(hf_config, "method", getattr(text_config, "method", None))

    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    loader = SEQ_CLS_LOAD_METHODS[method]
    return loader(model, weights)