adapters.py 21.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from contextlib import contextmanager
6
from typing import TYPE_CHECKING, Any, TypeVar, cast
7
8
9
10

import torch
import torch.nn as nn

11
from vllm.config import VllmConfig
12
13
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
14
from vllm.model_executor.models.config import VerifyAndUpdateConfig
15
16
17
from vllm.transformers_utils.config import (
    try_get_dense_modules,
)
18
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
19

20
from .interfaces import supports_multimodal
21
from .interfaces_base import VllmModelForPooling, is_pooling_model
22

23
if TYPE_CHECKING:
24
    from vllm.config import ModelConfig, VllmConfig
25
    from vllm.model_executor.layers.pooler import Pooler
26

27
28
# Type variable for the class adapters below: each takes an ``nn.Module``
# subclass and returns a (derived) class of the same kind.
_T = TypeVar("_T", bound=type[nn.Module])

logger = init_logger(__name__)

31
32
33
34
35
36
# Class-name suffixes of generative architectures. They are stripped, in this
# order, by `_get_pooling_model_name` before a pooling suffix is appended.
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]
37
38


39
def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
    """Load Sentence-Transformers Dense projection layers.

    Reads the Dense module descriptions via ``try_get_dense_modules`` and
    builds one ``nn.Linear`` (followed by its activation, when configured)
    per entry whose weights could be loaded.

    Args:
        model_config: Provides the repo id, revision and ``head_dtype`` used
            for the projection layers.

    Returns:
        An ``nn.Sequential`` of the loaded layers cast to ``head_dtype``, or
        ``None`` if the model declares no Dense modules or loading raised.
    """
    dense_modules = try_get_dense_modules(
        model_config.model, revision=model_config.revision
    )
    if dense_modules is None:
        return None

    try:
        layers: list[nn.Module] = []
        for layer_config in dense_modules:
            folder = layer_config["folder"]
            linear = nn.Linear(
                layer_config["in_features"],
                layer_config["out_features"],
                bias=layer_config.get("bias", True),
                dtype=model_config.head_dtype,
            )

            if not _load_dense_weights(linear, folder, model_config):
                # Dropping a layer changes the projector's output (and possibly
                # its shape), so surface it instead of skipping silently.
                logger.warning(
                    "Skipping Sentence-Transformers Dense layer in %r: "
                    "no loadable weights were found.",
                    folder,
                )
                continue
            layers.append(linear)

            if act_name := layer_config.get("activation_function"):
                layers.append(get_act_fn(act_name))
        return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
    except Exception:
        logger.exception("ST projector loading failed")

    return None


71
72
73
def _load_dense_weights(
    linear: nn.Linear, folder: str, model_config: "ModelConfig"
) -> bool:
    """Populate ``linear`` from a Dense module checkpoint on the Hub.

    Candidate files are ``model.safetensors`` then ``pytorch_model.bin``,
    under ``folder`` (or at the repo root when ``folder`` is empty). The first
    file containing ``weight``, ``linear.weight`` or ``dense.weight`` wins;
    the matching bias is copied too when present and ``linear`` has a bias.
    Tensors are copied via the parameter's ``weight_loader`` (falling back to
    vLLM's default loader). Any error for one file is logged and the next
    candidate is tried.

    Returns:
        ``True`` once weights were copied into ``linear``, else ``False``.
    """
    import io

    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    def _deserialize(name: str, payload: bytes) -> dict[str, torch.Tensor]:
        # Pick the decoder from the file extension.
        if name.endswith(".safetensors"):
            from safetensors.torch import load as load_safetensors

            return load_safetensors(payload)
        return torch.load(
            io.BytesIO(payload), map_location="cpu", weights_only=True
        )

    def _copy_into(param: torch.Tensor, tensor: torch.Tensor) -> None:
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, tensor)

    candidate_keys = ("weight", "linear.weight", "dense.weight")
    for filename in ("model.safetensors", "pytorch_model.bin"):
        path = f"{folder}/{filename}" if folder else filename
        try:
            payload = get_hf_file_bytes(
                path, model_config.model, model_config.revision
            )
            if not payload:
                continue

            state_dict = _deserialize(filename, payload)
            weight_key = next((k for k in candidate_keys if k in state_dict), None)
            if weight_key is None:
                continue

            _copy_into(linear.weight, state_dict[weight_key])
            bias_key = weight_key.replace("weight", "bias")
            if linear.bias is not None and bias_key in state_dict:
                _copy_into(linear.bias, state_dict[bias_key])
            return True
        except Exception:
            logger.exception("Failed to load %s", filename)

    return False


119
120
121
122
123
124
125
126
127
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Derive the class name of a pooling adapter for ``orig_model_name``.

    Every suffix in ``_GENERATE_SUFFIXES`` is stripped in turn (so more than
    one may be removed), then ``pooling_suffix`` is appended, e.g.
    ``"LlamaForCausalLM"`` + ``"ForEmbedding"`` -> ``"LlamaForEmbedding"``.
    """
    stem = orig_model_name
    for generate_suffix in _GENERATE_SUFFIXES:
        stem = stem.removesuffix(generate_suffix)
    return f"{stem}{pooling_suffix}"


128
129
130
131
def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """
    Create a ``ModelForPooling`` subclass of ``orig_cls``.

    The subclass:
    - constructs ``orig_cls`` inside ``no_init_weights`` so that its
      ``LogitsProcessor`` / ``ParallelLMHead`` layers are replaced by
      ``StageMissingLayer("output", ...)``;
    - keeps ``vllm_config`` on the instance (read by ``SEQ_CLS_LOAD_METHODS``);
    - reuses an existing ``pooler`` (its own or, for multimodal models, the
      language model's) and otherwise calls ``_init_pooler``, which concrete
      subclasses must implement;
    - remaps checkpoint weight names that are relative to ``self.model`` onto
      this model before loading.

    The class name ``ModelForPooling`` matters: the seq-cls weight loaders
    locate this class in the MRO by that name.
    """
    # Lazy import
    from itertools import chain

    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

    from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights

    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            with no_init_weights(
                self,
                lambda mod: StageMissingLayer("output", mod),
                targets=(LogitsProcessor, ParallelLMHead),
            ):
                super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Used by SEQ_CLS_LOAD_METHODS
            self.vllm_config = vllm_config

            # If the model already defines a pooler instance, don't overwrite
            # it. Test against None rather than truthiness: a module defining
            # __len__ (e.g. a container module) would be falsy when empty.
            pooler = getattr(self, "pooler", None)
            if pooler is None and supports_multimodal(self):
                # Try to get the pooler from the LM backbone
                pooler = getattr(self.get_language_model(), "pooler", None)

            if pooler is None:
                pooler = self._init_pooler(vllm_config, prefix=prefix)

            self.pooler = pooler

        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            params_dict = dict(self.named_parameters())

            # We support loading from both `*ForCausalLM` and `*Model`
            candidate_prefixes = ["", "model."]
            target_prefix = ""

            # Use one explicit iterator so the weights consumed while probing
            # are not replayed when `weights` is a re-iterable container.
            weights_iter = iter(weights)
            seen_weights = list[tuple[str, torch.Tensor]]()
            for name, loaded_weight in weights_iter:
                seen_weights.append((name, loaded_weight))

                matched = next(
                    (p for p in candidate_prefixes if p + name in params_dict),
                    None,
                )
                # The weight might not exist on the model (it is then left to
                # AutoWeightsLoader); keep probing until one name matches.
                if matched is not None:
                    target_prefix = matched
                    break

            if target_prefix:
                # Walk the dotted prefix from self (e.g. "model." -> self.model).
                target_model = self
                for attr in target_prefix.split("."):
                    if attr:
                        target_model = getattr(target_model, attr)

                logger.info(
                    "Mapping weights to %s as they are "
                    "relative to this model instead of %s.",
                    target_model._get_name(),
                    self._get_name(),
                )

            # Replay the probed weights, then stream the rest lazily instead of
            # materialising the whole checkpoint up front.
            mapped_weights = (
                (target_prefix + name, weight)
                for name, weight in chain(seen_weights, weights_iter)
            )

            def default_load_weights(weights):
                loader = AutoWeightsLoader(self)
                return loader.load_weights(weights)

            load_weights = getattr(super(), "load_weights", default_load_weights)
            return load_weights(mapped_weights)

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Wrap an existing vLLM model class so it can serve embedding requests.

    By default, the embeddings of the whole prompt are extracted from the
    normalized hidden state corresponding to the last token.

    Note:
        No extra layers are expected on top of the original model; write a
        dedicated model implementation if that assumption does not hold.
    """
    # Classes that already pool are handed back untouched.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForEmbedding(pooling_base):
        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            # The pooler is configured entirely from the model's pooler config.
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None
            return DispatchPooler.for_embedding(pooler_config)

    ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
    return ModelForEmbedding  # type: ignore
258
259


260
def as_seq_cls_model(cls: _T) -> _T:
    """
    Wrap an existing vLLM model class so it supports classify and score tasks.

    By default, the class probabilities are extracted from the softmaxed
    hidden state corresponding to the last token.

    Note:
        The classification head is assumed to be a single linear layer held
        in the top-level model's `score` attribute; write a dedicated model
        implementation if that is not the case.
    """
    # Classes that already pool are handed back untouched.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.pooler import DispatchPooler
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding

    from .utils import maybe_prefix

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForSequenceClassification(pooling_base, SupportsCrossEncoding):
        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            model_config = vllm_config.model_config
            num_labels = model_config.hf_config.get_text_config().num_labels

            # Bias-free linear head; load_weights installs a bias only when
            # the checkpoint ships `score.bias`.
            self.score = ReplicatedLinear(
                model_config.get_hidden_size(),
                num_labels,
                bias=False,
                params_dtype=model_config.head_dtype,
                quant_config=vllm_config.quant_config,
                return_bias=False,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = model_config.pooler_config
            assert pooler_config is not None
            return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            hf_config = self.config
            text_config = hf_config.get_text_config()

            def _lookup(attr: str):
                # Top-level config wins; fall back to the text config.
                return getattr(hf_config, attr, getattr(text_config, attr, None))

            tokens = _lookup("classifier_from_token")
            method = _lookup("method")

            def _absorb_score_bias(stream):
                # Intercept `score.bias`: install it directly on the head
                # (matching the weight's device/dtype) instead of forwarding it.
                for name, tensor in stream:
                    if name != "score.bias":
                        yield name, tensor
                        continue
                    reference = self.score.weight
                    converted = tensor.to(reference.device).to(reference.dtype)
                    self.score.bias = torch.nn.Parameter(converted)
                    self.score.skip_bias_add = False

            weights = _absorb_score_bias(weights)
            if tokens is not None or method is not None:
                # Online convert ForCausalLM into
                # ForSequenceClassification model.
                return seq_cls_model_loader(self, weights)
            return super().load_weights(weights)

    ModelForSequenceClassification.__name__ = _get_pooling_model_name(
        cls.__name__, "ForSequenceClassification"
    )
    return ModelForSequenceClassification  # type: ignore
344
345


346
347
348
class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config fix-ups for causal LMs that are converted to classifiers online."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate the seq-cls conversion settings and size the score head.

        Nothing is changed unless a ``method`` is configured (top-level HF
        config first, then the text config). With a method, the
        ``classifier_from_token`` list must be present and the method must be
        a key of ``SEQ_CLS_LOAD_METHODS``. ``num_labels`` is then set on both
        configs: 1 for ``from_2_way_softmax`` (which needs exactly two
        tokens), otherwise one label per token. ``use_sep_token`` on the text
        config defaults to ``False``.
        """
        hf_config = vllm_config.model_config.hf_config
        text_config = hf_config.get_text_config()
        method = getattr(hf_config, "method", getattr(text_config, "method", None))
        tokens = getattr(
            hf_config,
            "classifier_from_token",
            getattr(text_config, "classifier_from_token", None),
        )

        if method is None:
            return

        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            assert len(tokens) == 2
            num_labels = 1
        else:
            num_labels = len(tokens)
        hf_config.num_labels = num_labels
        text_config.num_labels = num_labels

        # `llm as reranker` defaults to not using separating token.
        text_config.use_sep_token = getattr(text_config, "use_sep_token", False)
375

376

377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
def _get_language_model_for_seq_cls(model) -> nn.Module:
    """
    Get the language model component for sequence classification conversion.

    Resolution order:
    1. For multimodal models, ``model.get_language_model()`` if it succeeds
       and returns a module other than ``model``.
    2. The first of ``language_model`` / ``lm`` / ``text_model`` that is an
       ``nn.Module`` other than ``model`` and has a ``model`` attribute.
    3. The first direct child whose class name contains ``ForCausalLM`` or
       ``LMHead`` and that has a ``model`` attribute.
    Otherwise ``model`` itself is returned (standard LLMs).
    """
    if supports_multimodal(model):
        try:
            lm = model.get_language_model()
        except Exception:
            # Best effort: fall through to the attribute-based heuristics, but
            # leave a trace instead of swallowing the error silently.
            logger.debug(
                "get_language_model() failed on %s; falling back to "
                "attribute lookup.",
                type(model).__name__,
                exc_info=True,
            )
        else:
            if lm is not model:
                return lm

    for attr_name in ("language_model", "lm", "text_model"):
        candidate = getattr(model, attr_name, None)
        if (
            isinstance(candidate, nn.Module)
            and candidate is not model
            and hasattr(candidate, "model")
        ):
            return candidate

    for child in model.children():
        child_name = type(child).__name__
        if ("ForCausalLM" in child_name or "LMHead" in child_name) and hasattr(
            child, "model"
        ):
            return child

    return model


@contextmanager
def _disable_seq_cls_loading_on_inner_model(language_model, is_vlm: bool):
    """
    Temporarily hide the seq-cls conversion settings of an inner VLM model.

    ``method`` and ``classifier_from_token`` are read from the HF config
    first and the text config second, so all four attributes must be cleared
    to keep the inner model's ``load_weights`` from routing back into
    ``seq_cls_model_loader``. Attributes that were set are cleared for the
    duration of the ``with`` block and restored afterwards, even if the body
    raises.

    Does nothing when ``is_vlm`` is false or the inner model has no ``config``.
    """
    if not is_vlm:
        yield
        return

    inner_hf_config = getattr(language_model, "config", None)
    if inner_hf_config is None:
        yield
        return

    inner_text_config = inner_hf_config.get_text_config()
    # Snapshot everything before modifying anything: the two configs may be
    # the same object, in which case entries simply repeat.
    saved = [
        (config, attr, getattr(config, attr, None))
        for config in (inner_text_config, inner_hf_config)
        for attr in ("method", "classifier_from_token")
    ]

    try:
        for config, attr, value in saved:
            if value is not None:
                setattr(config, attr, None)
        yield
    finally:
        for config, attr, value in saved:
            if value is not None:
                setattr(config, attr, value)


447
def load_weights_using_from_2_way_softmax(
    model, weights: Iterable[tuple[str, torch.Tensor]]
):
    """Load a causal LM checkpoint as a single-logit binary classifier.

    Implements the ``from_2_way_softmax`` conversion (see
    https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3). A
    temporary ``lm_head`` is attached to the language model, tied to the
    input embeddings when ``tie_word_embeddings`` is set, so the checkpoint
    loads unchanged. The score weight is then
    ``lm_head[true_token] - lm_head[false_token]`` computed in float32, where
    ``classifier_from_token == [false_token, true_token]``. The temporary head
    is deleted afterwards.

    Returns:
        The set of loaded parameter names, with the score weight added and the
        (possibly remapped) ``lm_head.weight`` removed.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    vllm_config = model.vllm_config
    model_config = vllm_config.model_config
    quant_config = vllm_config.quant_config
    hf_config = model.config
    text_config = hf_config.get_text_config()

    tokens = cast(
        list[int],
        getattr(
            hf_config,
            "classifier_from_token",
            getattr(text_config, "classifier_from_token", []),
        ),
    )
    assert len(tokens) == 2
    false_token, true_token = tokens

    language_model = _get_language_model_for_seq_cls(model)
    is_vlm = language_model is not model
    # A VLM may carry its own `score` head on the inner language model.
    using_vlm_head = is_vlm and hasattr(language_model, "score")

    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
    )
    if text_config.tie_word_embeddings:
        # `embed_tokens` is the usual name for the input embeddings; the
        # Transformers modeling backend exposes get_input_embeddings() instead.
        backbone = language_model.model
        if hasattr(backbone, "embed_tokens"):
            embed_tokens = backbone.embed_tokens
        else:
            embed_tokens = backbone.get_input_embeddings()
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
        # ModelForPooling is created dynamically by _create_pooling_model_cls,
        # so it can only be found by name in the MRO.
        pooling_model_cls = next(
            klass
            for klass in type(model).__mro__
            if klass.__name__ == "ModelForPooling"
        )
        loaded_weights = pooling_model_cls.load_weights(model, weights)

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )
    false_id = tokenizer.convert_tokens_to_ids(false_token)
    true_id = tokenizer.convert_tokens_to_ids(true_token)

    head_rows = language_model.lm_head.weight.data
    true_row = head_rows[[true_id]].to(torch.float32)
    false_row = head_rows[[false_id]].to(torch.float32)
    score_weight = true_row - false_row

    score_param = (language_model.score if using_vlm_head else model.score).weight
    load_score = getattr(score_param, "weight_loader", default_weight_loader)
    load_score(score_param, score_weight)

    # The temporary head is only needed to read the token rows.
    del language_model.lm_head

    loaded_weights.add(
        "language_model.score.weight" if using_vlm_head else "score.weight"
    )

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)
    return loaded_weights


529
530
531
def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load a causal LM checkpoint as a multi-label classifier.

    Implements the ``no_post_processing`` conversion. A temporary ``lm_head``
    is attached to the language model, tied to the input embeddings when
    ``tie_word_embeddings`` is set, so the checkpoint loads unchanged. The
    score weight is then set to the ``lm_head`` rows of the
    ``classifier_from_token`` tokens (one row per label, in order). The
    temporary head is deleted afterwards.

    Returns:
        The set of loaded parameter names, with the score weight added and the
        (possibly remapped) ``lm_head.weight`` removed.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    model_config = model.vllm_config.model_config
    quant_config = model.vllm_config.quant_config
    hf_config = model.config
    text_config = hf_config.get_text_config()

    # Resolve the tokens the same way SequenceClassificationConfig does (it
    # sizes the score head from them): top-level config first, then the text
    # config. Reading only the text config misses tokens that a VLM sets on
    # its top-level config.
    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", []),
    )
    tokens = cast(list[int], tokens)
    assert len(tokens) > 0

    language_model = _get_language_model_for_seq_cls(model)
    is_vlm = language_model is not model
    # A VLM may carry its own `score` head on the inner language model.
    using_vlm_head = is_vlm and hasattr(language_model, "score")

    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        embed_tokens = (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
        pooling_model_cls = next(
            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
        )
        # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
        loaded_weights = pooling_model_cls.load_weights(model, weights)

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    score_weight = language_model.lm_head.weight.data[token_ids]

    score_layer = language_model.score if using_vlm_head else model.score
    param = score_layer.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The temporary head is only needed to read the token rows.
    del language_model.lm_head

    score_weight_name = (
        "language_model.score.weight" if using_vlm_head else "score.weight"
    )
    loaded_weights.add(score_weight_name)

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)

    return loaded_weights


598
599
# Maps the `method` config value to the loader that converts a causal LM
# checkpoint into a sequence-classification model.
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
    "no_post_processing": load_weights_no_post_processing,
}


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Online-convert a ForCausalLM checkpoint into a ForSequenceClassification model.

    The conversion ``method`` is read from the HF config, falling back to the
    text config, and must be a key of ``SEQ_CLS_LOAD_METHODS``; the matching
    loader receives ``model`` and ``weights`` and its result is returned.

    Known uses:
    - from_2_way_softmax: Qwen3ForCausalLM (Qwen3-Reranker),
      Qwen2ForCausalLM (mxbai-rerank-v2)
    - no_post_processing: GemmaForCausalLM (bge-reranker-v2-gemma)
    """
    hf_config = model.vllm_config.model_config.hf_config
    fallback = getattr(hf_config.get_text_config(), "method", None)
    method = getattr(hf_config, "method", fallback)

    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    loader = SEQ_CLS_LOAD_METHODS[method]
    return loader(model, weights)