# adapters.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import ast
import inspect
6
from collections.abc import Iterable
7
from typing import TYPE_CHECKING, Any, TypeVar, cast
8
9
10
11

import torch
import torch.nn as nn

12
from vllm.config import VllmConfig
13
14
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
15
from vllm.model_executor.models.config import VerifyAndUpdateConfig
16
17
18
from vllm.transformers_utils.config import (
    try_get_dense_modules,
)
19
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
20

21
from .interfaces_base import VllmModelForPooling, is_pooling_model
22

23
if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig

# Type variable for the model classes that the adapters below wrap/return.
_T = TypeVar("_T", bound=type[nn.Module])

logger = init_logger(__name__)

# Class-name suffixes of generative architectures. They are stripped (in this
# order) by `_get_pooling_model_name` before the pooling suffix is appended.
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]
36
37


38
def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
    """Load Sentence-Transformers Dense projection layers.

    Builds an ``nn.Sequential`` of ``nn.Linear`` layers (each optionally
    followed by its configured activation) from the Dense modules reported by
    ``try_get_dense_modules`` for ``model_config.model``.

    Loading is best-effort:

    - ``None`` is returned when no Dense modules are declared.
    - ``None`` is returned when any layer's weights cannot be loaded, rather
      than dropping that layer and producing a projector whose input/output
      dimensions no longer match the checkpoint.
    - Any exception is logged and ``None`` is returned.

    Args:
        model_config: Model configuration. Its ``model``, ``revision`` and
            ``head_dtype`` attributes are read.

    Returns:
        The projector module cast to ``model_config.head_dtype``, or ``None``.
    """
    dense_modules = try_get_dense_modules(
        model_config.model, revision=model_config.revision
    )
    if dense_modules is None:
        return None

    try:
        layers: list[nn.Module] = []
        for layer_config in dense_modules:
            folder = layer_config["folder"]
            linear = nn.Linear(
                layer_config["in_features"],
                layer_config["out_features"],
                bias=layer_config.get("bias", True),
                dtype=model_config.head_dtype,
            )

            if not _load_dense_weights(linear, folder, model_config):
                # Skipping just this layer would silently change the shape of
                # the projection (and leave the remaining layers misaligned),
                # so fall back to "no projector" like the error path does.
                logger.warning(
                    "Could not load weights for Dense layer %r; "
                    "skipping the ST projector",
                    folder,
                )
                return None

            layers.append(linear)
            if act_name := layer_config.get("activation_function"):
                layers.append(get_act_fn(act_name))

        return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
    except Exception:
        logger.exception("ST projector loading failed")

    return None


70
71
72
def _load_dense_weights(
    linear: nn.Linear, folder: str, model_config: "ModelConfig"
) -> bool:
    """Load one Sentence-Transformers Dense layer's weights into ``linear``.

    Candidate files are ``model.safetensors`` then ``pytorch_model.bin``,
    looked up inside ``folder`` (or at the repository root when ``folder`` is
    empty) via ``get_hf_file_bytes``. The first file containing one of the
    keys ``weight``, ``linear.weight`` or ``dense.weight`` is used; the
    matching ``*.bias`` entry is loaded too when ``linear`` has a bias.
    Tensors are copied with the parameter's ``weight_loader`` if it has one,
    else with ``default_weight_loader``.

    Args:
        linear: Destination layer, modified in place.
        folder: Sub-folder of the repository holding the layer's files.
        model_config: Supplies the repository id (``model``) and ``revision``.

    Returns:
        ``True`` once a weight tensor has been loaded, ``False`` if no
        candidate file yielded one. Per-file errors are logged and the next
        candidate is tried.
    """
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    for filename in ["model.safetensors", "pytorch_model.bin"]:
        file_path = f"{folder}/{filename}" if folder else filename

        try:
            file_bytes = get_hf_file_bytes(
                file_path, model_config.model, model_config.revision
            )
            if not file_bytes:
                continue

            if filename.endswith(".safetensors"):
                from safetensors.torch import load as load_safetensors

                state_dict = load_safetensors(file_bytes)
            else:
                import io

                # weights_only=True: never unpickle arbitrary objects.
                state_dict = torch.load(
                    io.BytesIO(file_bytes), map_location="cpu", weights_only=True
                )

            for weight_key in ["weight", "linear.weight", "dense.weight"]:
                if weight_key not in state_dict:
                    continue

                weight_loader = getattr(
                    linear.weight, "weight_loader", default_weight_loader
                )
                weight_loader(linear.weight, state_dict[weight_key])

                if linear.bias is not None:
                    bias_key = weight_key.replace("weight", "bias")
                    if bias_key in state_dict:
                        bias_loader = getattr(
                            linear.bias, "weight_loader", default_weight_loader
                        )
                        bias_loader(linear.bias, state_dict[bias_key])
                    else:
                        # nn.Linear's bias is randomly initialized by default;
                        # surface that instead of silently keeping it.
                        logger.warning(
                            "%s has no %r entry; keeping the default-initialized "
                            "bias of this Dense layer",
                            file_path,
                            bias_key,
                        )
                return True
        except Exception:
            # Log the full path so failures of different Dense layers can be
            # told apart.
            logger.exception("Failed to load %s", file_path)
            continue

    return False


118
119
120
121
122
123
124
125
126
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Derive the adapter class name from a generative model class name.

    Each entry of ``_GENERATE_SUFFIXES`` is stripped in turn (in list order)
    when the running name ends with it; ``pooling_suffix`` is then appended,
    e.g. ``"Qwen3ForCausalLM"`` + ``"ForEmbedding"`` -> ``"Qwen3ForEmbedding"``.
    """
    stem = orig_model_name
    for suffix in _GENERATE_SUFFIXES:
        if stem.endswith(suffix):
            stem = stem[: len(stem) - len(suffix)]
    return f"{stem}{pooling_suffix}"


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def try_create_mm_pooling_model_cls(orig_cls: "_T") -> "_T | None":
    """Wrap a (multimodal) model class as a pooling model, if applicable.

    The check is syntactic: ``orig_cls`` qualifies only if its own source code
    calls ``init_vllm_registered_model``, either as a bare name or as an
    attribute (e.g. ``utils.init_vllm_registered_model(...)``). The returned
    subclass then reuses the ``pooler`` of the inner language model.

    Args:
        orig_cls: The model class to inspect.

    Returns:
        A subclass of ``orig_cls`` that is also a ``VllmModelForPooling``, or
        ``None`` when the class does not build an inner registered model, or
        when its source code cannot be retrieved or parsed.
    """
    import textwrap

    try:
        # `inspect.getsource` keeps the original indentation, so classes
        # defined inside a function (e.g. produced by a class factory) must be
        # dedented before `ast.parse` accepts them.
        tree = ast.parse(textwrap.dedent(inspect.getsource(orig_cls)))
    except (OSError, TypeError, SyntaxError):
        # No retrievable/parsable source (dynamically created or builtin
        # class): we cannot show that the class qualifies.
        return None

    called_names = set()
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        if isinstance(func, ast.Name):
            called_names.add(func.id)
        elif isinstance(func, ast.Attribute):
            called_names.add(func.attr)

    if "init_vllm_registered_model" not in called_names:
        return None

    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Delegate pooling to the inner (registered) language model.
            self.pooler = self.get_language_model().pooler

    return ModelForPooling  # type: ignore


159
def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """Create a pooling-model subclass of ``orig_cls``.

    The generated ``ModelForPooling`` class:

    - deletes the generation-only ``lm_head`` / ``logits_processor``
      attributes from the model and from its ``language_model`` (if any);
    - calls ``_init_pooler`` unless the model already set ``self.pooler``
      (subclasses must override ``_init_pooler``);
    - drops ``lm_head.*`` weights in ``load_weights`` unless
      ``load_lm_head=True``.

    The class keeps the name ``ModelForPooling``:
    ``load_weights_using_from_2_way_softmax`` locates it in the MRO by that
    name.
    """
    # Lazy import
    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            self.vllm_config = vllm_config

            # These are not used in pooling models
            objects_to_clean = [self]
            language_model = getattr(self, "language_model", None)
            if language_model is not None:
                objects_to_clean.append(language_model)

            for obj in objects_to_clean:
                for attr in ("lm_head", "logits_processor"):
                    if hasattr(obj, attr):
                        delattr(obj, attr)

            # If the model already defines a pooler instance, don't overwrite
            # it. Compare against None: container modules (e.g. an empty
            # nn.Sequential) are falsy and must not count as "missing".
            if getattr(self, "pooler", None) is None:
                self._init_pooler(vllm_config, prefix=prefix)

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            raise NotImplementedError

        def load_weights(
            self,
            weights: Iterable[tuple[str, torch.Tensor]],
            load_lm_head: bool = False,
        ):
            # TODO: Support uninitialized params tracking

            # For most pooling models: We have deleted this attribute, so don't load it.
            # For converting an LLM into a seq cls model, we need the lm_head.
            if not load_lm_head:
                weights = (
                    (name, data)
                    for name, data in weights
                    if not name.startswith("lm_head.")
                )

            # If `*ForCausalLM` defines `load_weights` on the inner model
            # and there are no other inner modules with parameters,
            # we support loading from both `*Model` and `*ForCausalLM`
            if hasattr(self, "model") and hasattr(self.model, "load_weights"):
                # Whether only `self.model` contains parameters
                model_is_only_param = all(
                    name == "model" or next(child.parameters(), None) is None
                    for name, child in self.named_children()
                )

                if model_is_only_param:
                    # Accept `*ForCausalLM` checkpoints by stripping `model.`.
                    mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
                    weights = mapper.apply(weights)

                    loaded_params = self.model.load_weights(weights)
                    # Report names relative to this (outer) module.
                    return {f"model.{name}" for name in loaded_params}

            # For most other models
            if hasattr(orig_cls, "load_weights"):
                return orig_cls.load_weights(self, weights)  # type: ignore

            # Fallback
            loader = AutoWeightsLoader(self)
            return loader.load_weights(weights)

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support embeddings.

    By default, the embeddings of the whole prompt are extracted from the
    normalized hidden state corresponding to the last token.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.

    Returns ``cls`` unchanged if it is already a pooling model.
    """
    # Avoid modifying existing embedding models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler

    class ModelForEmbedding(_create_pooling_model_cls(cls)):
        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            self.pooler = DispatchPooler.for_embedding(pooler_config)

    # e.g. "Qwen3ForCausalLM" -> "Qwen3ForEmbedding"
    ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")

    return ModelForEmbedding  # type: ignore
267
268


269
def as_seq_cls_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support classify and score tasks.

    By default, the class probabilities are extracted from the softmaxed
    hidden state corresponding to the last token.

    Note:
        We assume that the classification head is a single linear layer
        stored as the attribute `score` of the top-level model;
        please implement your own model if this is not the case.

    Returns ``cls`` unchanged if it is already a pooling model. If the config
    sets ``method`` or ``classifier_from_token``, ``load_weights`` converts a
    ``*ForCausalLM`` checkpoint online via ``seq_cls_model_loader``.
    """
    # Avoid modifying existing classification models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.pooler import DispatchPooler
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding

    from .utils import maybe_prefix

    class ModelForSequenceClassification(
        _create_pooling_model_cls(cls), SupportsCrossEncoding
    ):
        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            model_config = vllm_config.model_config
            text_config = model_config.hf_config.get_text_config()

            # Classification head: hidden size -> num_labels, built without a
            # bias (a checkpoint `score.bias` is attached in `load_weights`).
            self.score = ReplicatedLinear(
                model_config.get_hidden_size(),
                text_config.num_labels,
                bias=False,
                params_dtype=model_config.head_dtype,
                quant_config=vllm_config.quant_config,
                return_bias=False,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = model_config.pooler_config
            assert pooler_config is not None

            self.pooler = DispatchPooler.for_seq_cls(
                pooler_config, classifier=self.score
            )

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            hf_config = self.config
            text_config = hf_config.get_text_config()
            # Top-level config wins; fall back to the text sub-config.
            tokens = getattr(
                hf_config,
                "classifier_from_token",
                getattr(text_config, "classifier_from_token", None),
            )
            method = getattr(hf_config, "method", getattr(text_config, "method", None))

            def auto_set_score_bias(weights):
                # `score` has no bias parameter; if the checkpoint provides
                # `score.bias`, attach it directly instead of yielding it.
                for name, weight in weights:
                    if name == "score.bias":
                        device = self.score.weight.device
                        dtype = self.score.weight.dtype
                        bias = weight.to(device=device, dtype=dtype)
                        self.score.bias = torch.nn.Parameter(bias)
                        self.score.skip_bias_add = False
                    else:
                        yield name, weight

            weights = auto_set_score_bias(weights)

            if tokens is None and method is None:
                return super().load_weights(weights)

            # Online convert ForCausalLM into
            # ForSequenceClassification model.
            return seq_cls_model_loader(self, weights)

    # e.g. "Qwen3ForCausalLM" -> "Qwen3ForSequenceClassification"
    ModelForSequenceClassification.__name__ = _get_pooling_model_name(
        cls.__name__, "ForSequenceClassification"
    )

    return ModelForSequenceClassification  # type: ignore
351
352


353
354
355
class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config fix-ups for LLMs converted online into classifiers/rerankers.

    When the HF config (or its text sub-config) sets ``method``, the number of
    labels is derived from ``classifier_from_token`` and ``use_sep_token`` is
    given a default.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        hf_config = vllm_config.model_config.hf_config
        text_config = hf_config.get_text_config()
        # Top-level config wins; fall back to the text sub-config.
        method = getattr(hf_config, "method", getattr(text_config, "method", None))
        tokens = getattr(
            hf_config,
            "classifier_from_token",
            getattr(text_config, "classifier_from_token", None),
        )

        # No online conversion requested.
        if method is None:
            return

        assert tokens is not None, (
            f"method {method} requires `classifier_from_token` to be set"
        )
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            # The (false, true) token pair collapses into a single score.
            assert len(tokens) == 2, (
                f"method {method} expects exactly 2 classifier tokens, "
                f"got {len(tokens)}"
            )
            num_labels = 1
        else:
            num_labels = len(tokens)
        hf_config.num_labels = num_labels
        text_config.num_labels = num_labels

        # `llm as reranker` defaults to not using separating token.
        use_sep_token = getattr(text_config, "use_sep_token", False)
        text_config.use_sep_token = use_sep_token
382

383
384

def load_weights_using_from_2_way_softmax(
    model, weights: Iterable[tuple[str, torch.Tensor]]
):
    """Load a ``*ForCausalLM`` checkpoint into a single-label classifier.

    ``classifier_from_token`` must hold two token strings
    ``[false_token, true_token]``. The ``score`` weight is set to
    ``lm_head[true_id] - lm_head[false_id]`` (computed in float32), i.e. the
    difference of the two token logits, whose sigmoid equals the 2-way
    softmax probability of the true token.

    Returns:
        The set of loaded parameter names, with ``score.weight`` added and the
        (mapped) ``lm_head.weight`` removed, since the temporary lm_head is
        deleted again.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    model_config = model.vllm_config.model_config
    quant_config = model.vllm_config.quant_config

    hf_config = model.config
    text_config = hf_config.get_text_config()

    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", []),
    )
    # Token *strings*; they are mapped to ids by the tokenizer below.
    tokens = cast(list[str], tokens)
    assert len(tokens) == 2

    # Re-create the lm_head that ModelForPooling removed, so the checkpoint's
    # lm_head weights can be loaded and turned into the score weight.
    language_model = (
        model.get_language_model() if hasattr(model, "get_language_model") else model
    )
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        embed_tokens = (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
    # function, so we need use this hacky method to obtain it.
    pooling_model_cls = next(
        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
    )
    loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    # NOTE(review): ParallelLMHead is presumably vocab-sharded under tensor
    # parallelism; indexing its local weight with global token ids may be
    # wrong when TP > 1 — confirm.
    lm_head_weight = language_model.lm_head.weight
    score_weight = lm_head_weight.data[[true_id]].to(
        torch.float32
    ) - lm_head_weight.data[[false_id]].to(torch.float32)

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The temporary lm_head has served its purpose.
    del language_model.lm_head
    loaded_weights.add("score.weight")

    # lm_head no longer exists on the model, so don't report it as loaded.
    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)

    return loaded_weights


459
460
461
def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load a ``*ForCausalLM`` checkpoint into an N-label classifier.

    The ``score`` weight rows are the ``lm_head`` rows of the token strings
    listed in ``classifier_from_token`` (read from the top-level HF config,
    falling back to the text sub-config), without further post-processing.

    Returns:
        The set of loaded parameter names, with ``score.weight`` added and the
        (mapped) ``lm_head.weight`` removed, since the temporary lm_head is
        deleted again.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    model_config = model.vllm_config.model_config
    quant_config = model.vllm_config.quant_config
    hf_config = model.config
    text_config = hf_config.get_text_config()

    # Same lookup order as SequenceClassificationConfig / the 2-way loader.
    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", []),
    )
    # Token *strings*; they are mapped to ids by the tokenizer below.
    tokens = cast(list[str], tokens)
    assert len(tokens) > 0

    # Re-create the lm_head that ModelForPooling removed, so the checkpoint's
    # lm_head weights can be loaded and turned into the score weight.
    language_model = (
        model.get_language_model() if hasattr(model, "get_language_model") else model
    )
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        embed_tokens = (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    # Call ModelForPooling.load_weights directly (skipping the sequence
    # classification override avoids infinite recursion). It is defined
    # dynamically, so locate it by name. load_lm_head=True is required:
    # otherwise `lm_head.*` weights are filtered out and, for untied models,
    # the score would be built from an unloaded lm_head.
    pooling_model_cls = next(
        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
    )
    loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)

    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    # NOTE(review): as in the 2-way loader, this assumes the lm_head weight
    # is not vocab-sharded (TP == 1) — confirm for TP > 1.
    score_weight = language_model.lm_head.weight.data[token_ids]

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The temporary lm_head has served its purpose.
    del language_model.lm_head
    loaded_weights.add("score.weight")

    # lm_head no longer exists on the model, so don't report it as loaded.
    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)
    return loaded_weights


510
511
# Maps the config's `method` value to the loader that builds the `score`
# weights from a `*ForCausalLM` checkpoint (see `seq_cls_model_loader`).
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
    "no_post_processing": load_weights_no_post_processing,
}


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Online convert a ``*ForCausalLM`` checkpoint into a
    ``*ForSequenceClassification`` model.

    Dispatches on the config's ``method`` (top-level HF config first, then the
    text sub-config) to the matching entry of ``SEQ_CLS_LOAD_METHODS`` and
    returns that loader's set of loaded parameter names.

    Known users:

    - from_2_way_softmax:
      - Qwen3ForCausalLM (Qwen3-Reranker)
      - Qwen2ForCausalLM (mxbai-rerank-v2)
    - no_post_processing:
      - GemmaForCausalLM (bge-reranker-v2-gemma)
    """
    hf_config = model.vllm_config.model_config.hf_config
    text_config = hf_config.get_text_config()
    method = getattr(hf_config, "method", getattr(text_config, "method", None))

    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    return SEQ_CLS_LOAD_METHODS[method](model, weights)