adapters.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from typing import TYPE_CHECKING, Any, TypeVar, cast
6
7
8
9

import torch
import torch.nn as nn

10
from vllm.config import VllmConfig
11
12
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
13
from vllm.model_executor.models.config import VerifyAndUpdateConfig
14
15
16
from vllm.transformers_utils.config import (
    try_get_dense_modules,
)
17
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
18

19
from .interfaces import supports_multimodal
20
from .interfaces_base import VllmModelForPooling, is_pooling_model
21

22
if TYPE_CHECKING:
23
    from vllm.config import ModelConfig, VllmConfig
24
    from vllm.model_executor.layers.pooler import Pooler
25

logger = init_logger(__name__)

# Type variable for the class-transforming helpers below (e.g.
# `as_embedding_model`), so the returned class is typed like the input class.
_T = TypeVar("_T", bound=type[nn.Module])

# Class-name suffixes that identify generative architectures. They are
# stripped, in this order, by `_get_pooling_model_name` before a pooling
# suffix (such as "ForEmbedding") is appended.
_GENERATE_SUFFIXES = [
    "ForCausalLM", "ForConditionalGeneration", "ChatModel", "LMHeadModel"
]


def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
    """Load Sentence-Transformers Dense projection layers.

    For every Dense module reported by `try_get_dense_modules`, builds an
    `nn.Linear` (followed by its activation, if one is configured) and fills
    its weights via `_load_dense_weights`.

    Args:
        model_config: Supplies the repo id (`model`), `revision`, and the
            `head_dtype` used for the projection layers.

    Returns:
        An `nn.Sequential` of the loaded layers. Returns `None` in three cases:
        the model declares no Dense modules, none of the layers could be
        loaded, or construction raised an exception (which is logged).
    """
    dense_modules = try_get_dense_modules(
        model_config.model, revision=model_config.revision
    )
    if dense_modules is None:
        return None

    try:
        layers: list[nn.Module] = []
        for layer_config in dense_modules:
            folder = layer_config["folder"]
            linear = nn.Linear(
                layer_config["in_features"],
                layer_config["out_features"],
                bias=layer_config.get("bias", True),
                dtype=model_config.head_dtype,
            )

            if not _load_dense_weights(linear, folder, model_config):
                # The layer is skipped, but say so. `_load_dense_weights`
                # does not log when the weight files are simply absent, and a
                # dropped layer changes the projection's output dimension.
                logger.warning(
                    "Skipping ST Dense layer in %r: weights could not be loaded",
                    folder,
                )
                continue

            layers.append(linear)
            if act_name := layer_config.get("activation_function"):
                layers.append(get_act_fn(act_name))

        if not layers:
            # Every layer was skipped. An empty nn.Sequential would silently
            # act as an identity projection, so report "no projector" instead.
            logger.warning("No ST Dense layer weights could be loaded")
            return None

        return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
    except Exception:
        logger.exception("ST projector loading failed")

    return None


def _load_dense_weights(
    linear: nn.Linear, folder: str, model_config: "ModelConfig"
) -> bool:
    """Load weights using vLLM's weight_loader pattern.

    Candidate files are `model.safetensors` and then `pytorch_model.bin`,
    looked up inside `folder` (or at the repo root when `folder` is empty)
    and fetched with `get_hf_file_bytes`. The first file whose state dict
    contains `weight`, `linear.weight` or `dense.weight` is used. The matching
    `*bias` entry is also loaded when `linear` has a bias.

    Args:
        linear: Layer whose `weight` (and `bias`, if any) is filled in.
        folder: Sub-folder of the model repo holding this Dense layer.
        model_config: Supplies the repo id (`model`) and `revision`.

    Returns:
        True if a weight tensor was loaded, otherwise False. Exceptions raised
        while reading or loading a file are logged, and the next candidate
        file is tried.
    """
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    for filename in ("model.safetensors", "pytorch_model.bin"):
        file_path = f"{folder}/{filename}" if folder else filename

        try:
            file_bytes = get_hf_file_bytes(
                file_path, model_config.model, model_config.revision
            )
            if not file_bytes:
                continue

            if filename.endswith(".safetensors"):
                from safetensors.torch import load as load_safetensors

                state_dict = load_safetensors(file_bytes)
            else:
                import io

                # weights_only=True refuses to unpickle arbitrary objects from
                # the downloaded checkpoint.
                state_dict = torch.load(
                    io.BytesIO(file_bytes), map_location="cpu", weights_only=True
                )

            for weight_key in ("weight", "linear.weight", "dense.weight"):
                if weight_key not in state_dict:
                    continue

                weight_loader = getattr(
                    linear.weight, "weight_loader", default_weight_loader
                )
                weight_loader(linear.weight, state_dict[weight_key])

                bias_key = weight_key.replace("weight", "bias")
                if linear.bias is not None and bias_key in state_dict:
                    bias_loader = getattr(
                        linear.bias, "weight_loader", default_weight_loader
                    )
                    bias_loader(linear.bias, state_dict[bias_key])
                return True
        except Exception:
            # Log the full path, not just the filename, so that failures in
            # different Dense folders can be told apart.
            logger.exception("Failed to load %s", file_path)

    return False


def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Derive a pooling adapter's class name from a generative model's name.

    Each entry of `_GENERATE_SUFFIXES` is removed from the end of the name in
    turn, so stacked suffixes are all stripped. `pooling_suffix` is then
    appended, e.g. ("LlamaForCausalLM", "ForEmbedding") -> "LlamaForEmbedding".
    """
    base_name = orig_model_name
    for suffix in _GENERATE_SUFFIXES:
        if base_name.endswith(suffix):
            base_name = base_name[: len(base_name) - len(suffix)]
    return f"{base_name}{pooling_suffix}"


def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """Create a pooling subclass (`ModelForPooling`) of `orig_cls`.

    The generated class does three things:
    - It builds `orig_cls` inside `no_init_weights`, targeting its
      `LogitsProcessor` / `ParallelLMHead` layers.
    - It sets up `self.pooler`, reusing an existing pooler if one is found.
    - It remaps checkpoint weight names that are relative to the inner
      `model` submodule.

    Subclasses must implement `_init_pooler`.
    """
    # Lazy import
    import itertools

    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

    from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights

    class ModelForPooling(orig_cls, VllmModelForPooling):
        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            # The generation-only layers are targeted by `no_init_weights` with
            # a `StageMissingLayer("output", ...)` factory. Presumably these
            # become placeholders that are never materialized; see `.utils`.
            with no_init_weights(
                self,
                lambda mod: StageMissingLayer("output", mod),
                targets=(LogitsProcessor, ParallelLMHead),
            ):
                super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Used by SEQ_CLS_LOAD_METHODS
            self.vllm_config = vllm_config

            # If the model already defines a pooler instance, don't overwrite it.
            # Compare against None explicitly: nn.Module truthiness depends on
            # whether the module happens to define __len__.
            pooler = getattr(self, "pooler", None)
            if pooler is None and supports_multimodal(self):
                # Try to get the pooler from the LM backbone
                language_model = self.get_language_model()
                pooler = getattr(language_model, "pooler", None)

            if pooler is None:
                pooler = self._init_pooler(vllm_config, prefix=prefix)

            self.pooler = pooler

        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            """Build the pooler for this model; implemented by subclasses."""
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            """Load weights, adding a "model." prefix to names when needed.

            Checkpoints of both `*ForCausalLM` and bare `*Model` classes are
            accepted. The first weight whose name matches a parameter of this
            module under one of the candidate prefixes decides the prefix
            applied to every weight.
            """
            params_dict = dict(self.named_parameters())

            # We support loading from both `*ForCausalLM` and `*Model`
            candidate_prefixes = ["", "model."]
            target_prefix = ""

            # Force single-pass iteration. With a re-iterable input (e.g. a
            # list), the weights consumed while probing would otherwise be
            # produced a second time when the remainder is chained below.
            weights = iter(weights)

            # Weights consumed while probing; they are replayed below.
            seen_weights: list[tuple[str, torch.Tensor]] = []
            for name, loaded_weight in weights:
                seen_weights.append((name, loaded_weight))

                matched = next(
                    (
                        cand
                        for cand in candidate_prefixes
                        if cand + name in params_dict
                    ),
                    None,
                )
                if matched is not None:
                    target_prefix = matched
                    break
                # Otherwise the weight might not exist on the model
                # (to be handled by AutoWeightsLoader).

            if target_prefix:
                # Walk the dotted prefix down from `self` (the result is only
                # used for the log message below).
                target_model = self
                for attr in target_prefix.split("."):
                    if attr:
                        target_model = getattr(target_model, attr)

                logger.info(
                    "Mapping weights to %s as they are "
                    "relative to this model instead of %s.",
                    target_model._get_name(),
                    self._get_name(),
                )

            # Replay the probed weights, then stream the rest lazily.
            # Unpacking `weights` into a tuple here would hold every
            # checkpoint tensor in memory at once.
            mapped_weights = (
                (target_prefix + name, weight)
                for name, weight in itertools.chain(seen_weights, weights)
            )

            def default_load_weights(weights):
                loader = AutoWeightsLoader(self)
                return loader.load_weights(weights)

            load_weights = getattr(super(), "load_weights", default_load_weights)
            return load_weights(mapped_weights)

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Wrap an existing vLLM model class so that it can produce embeddings.

    Classes that are already pooling models are returned as-is. Otherwise a
    subclass is created whose pooler is `DispatchPooler.for_embedding(...)`,
    built from the model's pooler config. By default, the embeddings of the
    whole prompt are extracted from the normalized hidden state corresponding
    to the last token.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    if is_pooling_model(cls):
        # Already an embedding/pooling model: nothing to adapt.
        return cls

    from vllm.model_executor.layers.pooler import DispatchPooler  # lazy import

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForEmbedding(pooling_base):
        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None
            return DispatchPooler.for_embedding(pooler_config)

    ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
    return ModelForEmbedding  # type: ignore


def as_seq_cls_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support classify and score tasks.

    Pooling models are returned unchanged. Otherwise the returned subclass
    adds a `score` head: a `ReplicatedLinear` from the hidden size to the text
    config's `num_labels`, which it passes to `DispatchPooler.for_seq_cls`. By
    default, the class probabilities are extracted from the softmaxed hidden
    state corresponding to the last token.

    When the HF config defines `classifier_from_token` or `method`, weights are
    loaded through `seq_cls_model_loader`, which converts a causal-LM
    checkpoint online.

    Note:
        We assume that the classification head is a single linear layer
        stored as the attribute `score` of the top-level model;
        please implement your own model if this is not the case.
    """
    # Leave existing classification/pooling models untouched.
    if is_pooling_model(cls):
        return cls

    # Lazy imports
    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.pooler import DispatchPooler
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding

    from .utils import maybe_prefix

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForSequenceClassification(pooling_base, SupportsCrossEncoding):
        def _init_pooler(
            self,
            vllm_config: "VllmConfig",
            prefix: str = "",
        ) -> "Pooler":
            model_config = vllm_config.model_config
            num_labels = model_config.hf_config.get_text_config().num_labels

            # The classification head is created without a bias. A
            # `score.bias` present in the checkpoint is attached in
            # `load_weights`.
            self.score = ReplicatedLinear(
                model_config.get_hidden_size(),
                num_labels,
                bias=False,
                params_dtype=model_config.head_dtype,
                quant_config=vllm_config.quant_config,
                return_bias=False,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = model_config.pooler_config
            assert pooler_config is not None
            return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            hf_config = self.config
            text_config = hf_config.get_text_config()
            # The top-level config takes precedence over the text sub-config.
            tokens = getattr(
                hf_config,
                "classifier_from_token",
                getattr(text_config, "classifier_from_token", None),
            )
            method = getattr(hf_config, "method", getattr(text_config, "method", None))

            def absorb_score_bias(named_tensors):
                # Remove `score.bias` from the stream and install it on the
                # head directly, on the head weight's device and dtype.
                for name, tensor in named_tensors:
                    if name != "score.bias":
                        yield name, tensor
                        continue
                    head_weight = self.score.weight
                    moved = tensor.to(head_weight.device).to(head_weight.dtype)
                    self.score.bias = torch.nn.Parameter(moved)
                    self.score.skip_bias_add = False

            filtered = absorb_score_bias(weights)

            if tokens is not None or method is not None:
                # Online convert ForCausalLM into
                # ForSequenceClassification model.
                return seq_cls_model_loader(self, filtered)
            return super().load_weights(filtered)

    ModelForSequenceClassification.__name__ = _get_pooling_model_name(
        cls.__name__, "ForSequenceClassification"
    )
    return ModelForSequenceClassification  # type: ignore


class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Adjust the HF config of an LLM served as a sequence classifier.

    This only applies when the HF config (or its text sub-config) sets `method`.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate `method` / `classifier_from_token` and set `num_labels`.

        - `from_2_way_softmax`: exactly two tokens, collapsed into one label.
        - Any other supported method: one label per classifier token.

        Also defaults `text_config.use_sep_token` to False when it is unset.
        """
        hf_config = vllm_config.model_config.hf_config
        text_config = hf_config.get_text_config()

        def lookup(attr_name):
            # Prefer the top-level config; fall back to the text sub-config.
            return getattr(hf_config, attr_name, getattr(text_config, attr_name, None))

        method = lookup("method")
        tokens = lookup("classifier_from_token")

        if method is None:
            return

        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            assert len(tokens) == 2
            num_labels = 1
        else:
            num_labels = len(tokens)
        hf_config.num_labels = num_labels
        text_config.num_labels = num_labels

        # `llm as reranker` defaults to not using separating token.
        text_config.use_sep_token = getattr(text_config, "use_sep_token", False)

def load_weights_using_from_2_way_softmax(
    model, weights: Iterable[tuple[str, torch.Tensor]]
):
    """Load a causal-LM checkpoint as a single-logit two-way classifier.

    A temporary LM head is built and loaded. The score weight is set to
    ``W[true] - W[false]``, computed in float32, where "false" is
    ``classifier_from_token[0]`` and "true" is ``classifier_from_token[1]``.
    The temporary head is then deleted.

    Args:
        model: A model produced by `as_seq_cls_model`. Its `vllm_config`,
            `config` and `score` attributes are used.
        weights: Checkpoint ``(name, tensor)`` pairs.

    Returns:
        The set of loaded parameter names, with ``score.weight`` added and the
        (possibly remapped) LM head weight name removed.

    Raises:
        ValueError: If a classifier token cannot be resolved to a token id.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader
    from vllm.tokenizers import get_tokenizer

    model_config = model.vllm_config.model_config
    quant_config = model.vllm_config.quant_config

    hf_config = model.config
    text_config = hf_config.get_text_config()

    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", []),
    )
    # These are token *strings*; the tokenizer maps them to ids below.
    tokens = cast(list[str], tokens)
    assert len(tokens) == 2

    language_model = (
        model.get_language_model() if hasattr(model, "get_language_model") else model
    )
    # Temporary LM head, used only to extract the classifier token rows.
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        embed_tokens = (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
    # function, so we need use this hacky method to obtain it.
    pooling_model_cls = next(
        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
    )
    loaded_weights = pooling_model_cls.load_weights(model, weights)

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])
    for token, token_id in ((tokens[0], false_id), (tokens[1], true_id)):
        if token_id is None:
            raise ValueError(
                f"Classifier token {token!r} could not be converted to a token id"
            )

    # softmax([l_false, l_true])[1] == sigmoid(l_true - l_false), so the single
    # row W[true] - W[false] reproduces the two-way probability.
    lm_head_weight = language_model.lm_head.weight.data
    true_row = lm_head_weight[[true_id]].to(torch.float32)
    false_row = lm_head_weight[[false_id]].to(torch.float32)
    score_weight = true_row - false_row

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The LM head is not needed once the score row has been extracted.
    del language_model.lm_head
    loaded_weights.add("score.weight")

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)

    return loaded_weights


def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load a causal-LM checkpoint as a classifier over the given tokens.

    Row ``i`` of the score weight is the LM head row of
    ``classifier_from_token[i]``, so those tokens' logits are used directly as
    class scores. A temporary LM head is built and loaded for this purpose,
    then deleted.

    Args:
        model: A model produced by `as_seq_cls_model`. Its `vllm_config`,
            `config` and `score` attributes are used.
        weights: Checkpoint ``(name, tensor)`` pairs.

    Returns:
        The set of loaded parameter names, with ``score.weight`` added and the
        (possibly remapped) LM head weight name removed.

    Raises:
        ValueError: If a classifier token cannot be resolved to a token id.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader
    from vllm.tokenizers import get_tokenizer

    model_config = model.vllm_config.model_config
    quant_config = model.vllm_config.quant_config
    hf_config = model.config
    text_config = hf_config.get_text_config()

    # Check the top-level config first, as SequenceClassificationConfig does
    # when it sizes `num_labels`; otherwise the two could disagree.
    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", []),
    )
    # These are token *strings*; the tokenizer maps them to ids below.
    tokens = cast(list[str], tokens)
    assert len(tokens) > 0

    # For multimodal wrappers, the LM head and backbone live on the language
    # model rather than on the top-level model.
    language_model = (
        model.get_language_model() if hasattr(model, "get_language_model") else model
    )
    # Temporary LM head, used only to extract the classifier token rows.
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
        text_backbone = language_model.model
        embed_tokens = (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

    # Call ModelForPooling.load_weights directly. Going through the sequence
    # classification subclass would dispatch back here and recurse forever.
    # The class is found by name, not by MRO position, so further subclasses
    # of the adapted model do not break this.
    pooling_model_cls = next(
        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
    )
    loaded_weights = pooling_model_cls.load_weights(model, weights)

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        revision=model_config.tokenizer_revision,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
    )

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    unresolved = [t for t, i in zip(tokens, token_ids) if i is None]
    if unresolved:
        raise ValueError(
            f"Classifier tokens could not be converted to token ids: {unresolved}"
        )

    score_weight = language_model.lm_head.weight.data[token_ids]

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The LM head is not needed once the score rows have been extracted.
    del language_model.lm_head
    loaded_weights.add("score.weight")

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
    loaded_weights.discard(lm_head_name)
    return loaded_weights


# Dispatch table from the HF config's `method` value to the loader that
# converts a causal-LM checkpoint into a sequence-classification model.
SEQ_CLS_LOAD_METHODS = dict(
    from_2_way_softmax=load_weights_using_from_2_way_softmax,
    no_post_processing=load_weights_no_post_processing,
)


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Online-convert a ForCausalLM checkpoint into a ForSequenceClassification model.

    The conversion is chosen by the `method` field of the HF config, falling
    back to the text sub-config, and dispatched through `SEQ_CLS_LOAD_METHODS`.

    Known users:
      - from_2_way_softmax:
        - Qwen3ForCausalLM: Qwen3-Reranker
        - Qwen2ForCausalLM: mxbai-rerank-v2
      - no_post_processing:
        - GemmaForCausalLM: bge-reranker-v2-gemma
    """
    hf_config = model.vllm_config.model_config.hf_config
    fallback_method = getattr(hf_config.get_text_config(), "method", None)
    method = getattr(hf_config, "method", fallback_method)

    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    loader = SEQ_CLS_LOAD_METHODS[method]
    return loader(model, weights)