adapters.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import ast
import inspect
6
from collections.abc import Iterable
7
from typing import TYPE_CHECKING, Any, TypeVar, cast
8
9
10
11

import torch
import torch.nn as nn

12
from vllm.config import VllmConfig
13
14
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
15
from vllm.model_executor.models.config import VerifyAndUpdateConfig
16
17
18
19
from vllm.transformers_utils.config import (
    get_hf_file_bytes,
    try_get_dense_modules,
)
20

21
from .interfaces_base import VllmModelForPooling, is_pooling_model
22

23
if TYPE_CHECKING:
24
    from vllm.config import ModelConfig, VllmConfig
25

26
27
logger = init_logger(__name__)

# Type variable used to annotate the model-class adapter functions below,
# which take an ``nn.Module`` subclass and return a class.
_T = TypeVar("_T", bound=type[nn.Module])

# Class-name suffixes that identify a generative architecture. They are
# stripped by ``_get_pooling_model_name`` before a pooling suffix is appended.
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]
36
37


38
def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
    """Load Sentence-Transformers Dense projection layers.

    Builds an ``nn.Sequential`` from the Dense module configs returned by
    ``try_get_dense_modules``. Each config contributes an ``nn.Linear``
    (optionally followed by its activation function) when its weights can be
    loaded.

    Args:
        model_config: Model config; ``model``, ``revision`` and
            ``head_dtype`` are read.

    Returns:
        The projector cast to ``model_config.head_dtype``. Returns ``None``
        when the model defines no Dense modules, or when building the
        projector raised (the exception is logged).
    """
    dense_modules = try_get_dense_modules(
        model_config.model, revision=model_config.revision
    )
    if dense_modules is None:
        return None

    try:
        layers: list[nn.Module] = []
        for layer_config in dense_modules:
            folder = layer_config["folder"]
            linear = nn.Linear(
                layer_config["in_features"],
                layer_config["out_features"],
                bias=layer_config.get("bias", True),
                dtype=model_config.head_dtype,
            )
            if not _load_dense_weights(linear, folder, model_config):
                # The layer (and its activation) is dropped from the
                # projector, which changes its output; make that visible.
                logger.warning(
                    "No weights could be loaded for ST Dense layer in folder "
                    "%r; skipping this layer and its activation.",
                    folder,
                )
                continue
            layers.append(linear)
            if act_name := layer_config.get("activation_function"):
                layers.append(get_act_fn(act_name))
        return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
    except Exception:
        logger.exception("ST projector loading failed")

    return None


70
71
72
def _load_dense_weights(
    linear: nn.Linear, folder: str, model_config: "ModelConfig"
) -> bool:
    """Populate ``linear`` from a Dense-layer checkpoint file.

    Tries ``model.safetensors`` and then ``pytorch_model.bin`` inside
    ``folder`` (or the repo root when ``folder`` is empty). The first state
    dict containing ``weight``, ``linear.weight`` or ``dense.weight`` (checked
    in that order) is copied into ``linear.weight``. The matching ``*bias``
    entry is copied too when ``linear`` has a bias and the key exists. Copies
    go through the parameter's ``weight_loader``, falling back to
    ``default_weight_loader``.

    Returns:
        ``True`` once a weight was loaded, ``False`` if no candidate file
        yielded one. Errors for an individual file are logged and that file
        is skipped.
    """
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    def _to_state_dict(name: str, blob: bytes):
        # Decode the raw file bytes according to the checkpoint format.
        if name.endswith(".safetensors"):
            from safetensors.torch import load as load_safetensors

            return load_safetensors(blob)
        import io

        return torch.load(io.BytesIO(blob), map_location="cpu", weights_only=True)

    def _assign(param, tensor) -> None:
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, tensor)

    for candidate in ("model.safetensors", "pytorch_model.bin"):
        path = candidate if not folder else f"{folder}/{candidate}"
        try:
            blob = get_hf_file_bytes(path, model_config.model, model_config.revision)
            if not blob:
                continue

            state_dict = _to_state_dict(candidate, blob)
            matched_key = next(
                (
                    key
                    for key in ("weight", "linear.weight", "dense.weight")
                    if key in state_dict
                ),
                None,
            )
            if matched_key is None:
                continue

            _assign(linear.weight, state_dict[matched_key])
            bias_key = matched_key.replace("weight", "bias")
            if linear.bias is not None and bias_key in state_dict:
                _assign(linear.bias, state_dict[bias_key])
            return True
        except Exception:
            logger.exception("Failed to load %s", candidate)

    return False


118
119
120
121
122
123
124
125
126
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Derive a pooling-model class name from a generative class name.

    Each suffix in ``_GENERATE_SUFFIXES`` is stripped in list order (so
    stacked suffixes can all be removed). Then ``pooling_suffix`` is appended,
    e.g. ``"LlamaForCausalLM"`` + ``"ForEmbedding"`` -> ``"LlamaForEmbedding"``.
    """
    stem = orig_model_name
    for generative_suffix in _GENERATE_SUFFIXES:
        stem = stem.removesuffix(generative_suffix)
    return stem + pooling_suffix


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T | None:
    """Wrap a multimodal model class so it exposes its language model's pooler.

    A wrapper is created only when the source of ``orig_cls`` contains a call
    to ``init_vllm_registered_model``. Both a bare name call and an attribute
    call such as ``utils.init_vllm_registered_model(...)`` count. The
    wrapper's ``pooler`` is ``self.get_language_model().pooler``.

    Only the source of ``orig_cls`` itself is scanned, not its base classes.

    Returns:
        The pooling subclass. Returns ``None`` when no such call is found, or
        when the class source cannot be retrieved or parsed (e.g. builtin or
        dynamically created classes).
    """
    import textwrap

    try:
        source = inspect.getsource(orig_cls)
    except (OSError, TypeError):
        # Source unavailable: we cannot tell, so do not wrap.
        return None

    try:
        # getsource keeps the indentation of nested class definitions, which
        # ast.parse rejects with IndentationError unless dedented first.
        tree = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        return None

    def _called_name(call: ast.Call) -> str | None:
        func = call.func
        if isinstance(func, ast.Name):
            return func.id
        if isinstance(func, ast.Attribute):
            return func.attr
        return None

    builds_language_model = any(
        isinstance(node, ast.Call)
        and _called_name(node) == "init_vllm_registered_model"
        for node in ast.walk(tree)
    )
    if not builds_language_model:
        return None

    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Delegate pooling to the wrapped language model's pooler.
            self.pooler = self.get_language_model().pooler

    return ModelForPooling  # type: ignore


159
def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """Build a pooling-capable subclass of ``orig_cls``.

    The returned ``ModelForPooling`` class:

    * deletes the ``lm_head`` / ``logits_processor`` attributes (when present)
      after the parent constructor runs;
    * calls ``_init_pooler`` unless the parent already set a truthy
      ``pooler`` attribute (subclasses must override ``_init_pooler``);
    * drops ``lm_head.*`` tensors from the checkpoint in ``load_weights``.
    """
    # Lazy import
    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            self.vllm_config = vllm_config

            # Generation-only components are not used by pooling models.
            for unused_attr in ("lm_head", "logits_processor"):
                if hasattr(self, unused_attr):
                    delattr(self, unused_attr)

            # Keep a pooler the wrapped model already created.
            if getattr(self, "pooler", None):
                return
            self._init_pooler(vllm_config, prefix=prefix)

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            # Supplied by each concrete adapter (embedding / classify / reward).
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            # TODO: Support uninitialized params tracking

            # `lm_head` was deleted in __init__, so its tensors have no target.
            kept = (
                (name, tensor)
                for name, tensor in weights
                if not name.startswith("lm_head.")
            )

            # If the inner `model` has its own `load_weights` and no other
            # child owns parameters, accept checkpoints from either `*Model`
            # or `*ForCausalLM` by stripping the `model.` prefix.
            if hasattr(self, "model") and hasattr(self.model, "load_weights"):
                other_child_has_params = any(
                    child_name != "model"
                    and next(child.parameters(), None) is not None
                    for child_name, child in self.named_children()
                )
                if not other_child_has_params:
                    unprefixed = WeightsMapper(
                        orig_to_new_prefix={"model.": ""}
                    ).apply(kept)
                    inner_loaded = self.model.load_weights(unprefixed)
                    return {f"model.{name}" for name in inner_loaded}

            # Fallback when the original class has no loader of its own.
            if not hasattr(orig_cls, "load_weights"):
                return AutoWeightsLoader(self).load_weights(kept)

            # For most other models
            return orig_cls.load_weights(self, kept)  # type: ignore

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Turn an existing vLLM model class into one that supports embeddings.

    The result exposes ``encode`` and ``embed`` tasks through a
    ``DispatchPooler`` built from the model's ``pooler_config``. By default,
    the whole-prompt embedding is taken from the normalized hidden state of
    the last token.

    Note:
        The original model is assumed to have no extra layers; write a
        dedicated model class if that is not the case.
    """
    # Classes that already pool are returned unchanged.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForEmbedding(pooling_base):
        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            poolers_by_task = {
                "encode": Pooler.for_encode(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
            self.pooler = DispatchPooler(poolers_by_task)

    ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
    return ModelForEmbedding  # type: ignore
261
262


263
def as_seq_cls_model(cls: _T) -> _T:
    """
    Turn an existing vLLM model class into one that supports classify/score.

    By default, class probabilities come from the softmaxed hidden state of
    the last token. The model gains a float32 ``score`` head and a
    ``DispatchPooler`` serving ``encode``, ``classify`` and ``score``.

    Note:
        The classification head is assumed to be a single linear layer
        stored as the top-level attribute `score`; write a dedicated model
        class if that is not the case.
    """
    # Classes that already pool are returned unchanged.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.pooler import (
        ClassifierPooler,
        DispatchPooler,
        Pooler,
        PoolingMethod,
        PoolingType,
    )
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
    from vllm.sequence import IntermediateTensors

    from .utils import maybe_prefix

    class ModelForSequenceClassification(
        _create_pooling_model_cls(cls), SupportsCrossEncoding
    ):
        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            model_config = vllm_config.model_config
            hf_config = model_config.hf_config

            # Linear head hidden_size -> num_labels, no bias, float32 params.
            self.score = ReplicatedLinear(
                model_config.hidden_size,
                hf_config.num_labels,
                bias=False,
                params_dtype=torch.float32,
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = model_config.pooler_config
            assert pooler_config is not None

            pooling_type_name = pooler_config.pooling_type
            assert pooling_type_name is not None
            pooling_type = PoolingType[pooling_type_name]

            def classifier_pooler(act_fn):
                # Each task gets its own PoolingMethod; both share `score`.
                return ClassifierPooler(
                    pooling=PoolingMethod.from_pooling_type(pooling_type),
                    classifier=self._classifier,
                    act_fn=act_fn,
                )

            self.pooler = DispatchPooler(
                {
                    "encode": Pooler.for_encode(pooler_config),
                    "classify": classifier_pooler(
                        ClassifierPooler.act_fn_for_seq_cls(model_config)
                    ),
                    "score": classifier_pooler(
                        ClassifierPooler.act_fn_for_cross_encoder(model_config)
                    ),
                }
            )

        def _classifier(self, x: torch.Tensor):
            # `score` returns a pair; only the first element (the projection)
            # is used.
            logits, _ = self.score(x.float())
            return logits

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: IntermediateTensors | None = None,
            inputs_embeds: torch.Tensor | None = None,
        ) -> torch.Tensor:
            # Explicit pass-through to the wrapped model's forward.
            return super().forward(
                input_ids, positions, intermediate_tensors, inputs_embeds
            )

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            hf_config = self.config
            from_token = getattr(hf_config, "classifier_from_token", None)
            method = getattr(hf_config, "method", None)

            # A config naming classifier tokens or a conversion method marks a
            # ForCausalLM checkpoint that is converted on the fly.
            if from_token is not None or method is not None:
                return seq_cls_model_loader(self, weights)
            return super().load_weights(weights)

    ModelForSequenceClassification.__name__ = _get_pooling_model_name(
        cls.__name__, "ForSequenceClassification"
    )
    return ModelForSequenceClassification  # type: ignore
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384


def as_reward_model(cls: _T) -> _T:
    """
    Turn an existing vLLM model class into a reward model.

    The result serves only the ``encode`` task and defaults to the ``ALL``
    pooling type, i.e. the hidden states of every token are returned as-is.

    Note:
        The original model is assumed to have no extra layers; write a
        dedicated model class if that is not the case.
    """
    # Classes that already pool are returned unchanged.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    from .interfaces_base import default_pooling_type

    pooling_base = _create_pooling_model_cls(cls)

    @default_pooling_type("ALL")
    class ModelForReward(pooling_base):
        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            encode_pooler = Pooler.for_encode(pooler_config)
            self.pooler = DispatchPooler({"encode": encode_pooler})

    ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
    return ModelForReward  # type: ignore
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422


class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for on-the-fly ForCausalLM -> classifier conversion."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate and adjust ``hf_config`` for converted classifier models.

        Does nothing unless ``hf_config.method`` is set. Otherwise it
        requires ``classifier_from_token`` and a method listed in
        ``SEQ_CLS_LOAD_METHODS``. It then sets ``num_labels``: 1 for
        ``from_2_way_softmax``, which needs exactly two tokens, and one label
        per token otherwise. It also fills ``use_pad_token`` (default
        ``False``).
        """
        hf_config = vllm_config.model_config.hf_config
        method = getattr(hf_config, "method", None)
        tokens = getattr(hf_config, "classifier_from_token", None)
        if method is None:
            return

        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method != "from_2_way_softmax":
            hf_config.num_labels = len(tokens)
        else:
            assert len(tokens) == 2
            hf_config.num_labels = 1

        # `llm as reranker` defaults to not using pad_token
        hf_config.use_pad_token = getattr(hf_config, "use_pad_token", False)

427
428

def _attach_lm_head_and_load(
    model, weights: Iterable[tuple[str, torch.Tensor]], tokens: list[str]
) -> tuple[set[str], list[int]]:
    """Load a ForCausalLM checkpoint through a temporary ``lm_head``.

    Steps:

    * Attaches ``model.lm_head``: the input embeddings when
      ``tie_word_embeddings`` is set, otherwise a fresh ``ParallelLMHead``.
    * Loads ``weights`` with ``AutoWeightsLoader``.
    * Maps ``tokens`` to ids with the model's tokenizer.

    Returns:
        The set of loaded parameter names and the token ids in token order.

    Raises:
        ValueError: if the tokenizer returns ``None`` for any of ``tokens``.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.models.utils import AutoWeightsLoader
    from vllm.transformers_utils.tokenizer import get_tokenizer

    model_config = model.vllm_config.model_config

    if model.config.tie_word_embeddings:
        # Reuse the input embedding matrix as the output projection.
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(
            model.config.vocab_size,
            model.config.hidden_size,
            quant_config=model.vllm_config.quant_config,
        )

    loaded_weights = AutoWeightsLoader(model).load_weights(weights)

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )
    token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
    # A None id would otherwise surface later as an obscure indexing error.
    unresolved = [tok for tok, tid in zip(tokens, token_ids) if tid is None]
    if unresolved:
        raise ValueError(
            "classifier_from_token contains tokens the tokenizer cannot map "
            f"to ids: {unresolved}"
        )
    return loaded_weights, token_ids


def _replace_lm_head_with_score(
    model, loaded_weights: set[str], score_weight: torch.Tensor
) -> set[str]:
    """Copy ``score_weight`` into ``model.score.weight`` and drop ``lm_head``.

    Returns ``loaded_weights`` updated to list ``score.weight`` instead of
    ``lm_head.weight``.
    """
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The lm_head was only needed to derive the score weights.
    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights


def load_weights_using_from_2_way_softmax(
    model, weights: Iterable[tuple[str, torch.Tensor]]
):
    """Build a single-label score head from two classifier tokens.

    ``classifier_from_token`` lists ``[false_token, true_token]``. The score
    weight is ``lm_head[true] - lm_head[false]`` computed in float32.

    Returns:
        The set of loaded parameter names.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    tokens = cast(list[str], getattr(model.config, "classifier_from_token", []))
    assert len(tokens) == 2

    loaded_weights, (false_id, true_id) = _attach_lm_head_and_load(
        model, weights, tokens
    )

    lm_head_weight = model.lm_head.weight.data
    true_row = lm_head_weight[[true_id]].to(torch.float32)
    false_row = lm_head_weight[[false_id]].to(torch.float32)
    score_weight = true_row - false_row

    return _replace_lm_head_with_score(model, loaded_weights, score_weight)


def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Build a score head from the ``lm_head`` rows of the classifier tokens.

    One label per entry of ``classifier_from_token``, taken directly from the
    corresponding ``lm_head`` rows.

    Returns:
        The set of loaded parameter names.
    """
    tokens = cast(list[str], getattr(model.config, "classifier_from_token", []))
    assert len(tokens) > 0

    loaded_weights, token_ids = _attach_lm_head_and_load(model, weights, tokens)
    score_weight = model.lm_head.weight.data[token_ids]

    return _replace_lm_head_with_score(model, loaded_weights, score_weight)


521
522
# Maps `hf_config.method` to the loader that converts a ForCausalLM checkpoint
# into a sequence-classification `score` head.
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
    "no_post_processing": load_weights_no_post_processing,
}


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Convert a ForCausalLM checkpoint into a ForSequenceClassification model.

    Dispatches on ``hf_config.method`` through ``SEQ_CLS_LOAD_METHODS``.
    Known users:

    - ``from_2_way_softmax``: Qwen3ForCausalLM (Qwen3-Reranker),
      Qwen2ForCausalLM (mxbai-rerank-v2)
    - ``no_post_processing``: GemmaForCausalLM (bge-reranker-v2-gemma)
    """
    hf_config = model.vllm_config.model_config.hf_config
    method = getattr(hf_config, "method", None)
    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    loader = SEQ_CLS_LOAD_METHODS[method]
    return loader(model, weights)