"vllm/vscode:/vscode.git/clone" did not exist on "9ad0688e436a41e386fa49e81ee344cb59f7d23c"
adapters.py 18.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
6
7
8
9

import torch
import torch.nn as nn

10
11
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
12
from vllm.model_executor.models.config import VerifyAndUpdateConfig
13
14
from vllm.transformers_utils.config import (get_hf_file_bytes,
                                            get_hf_file_to_dict)
15

16
from .interfaces_base import VllmModelForPooling, is_pooling_model
17

18
if TYPE_CHECKING:
19
    from vllm.config import ModelConfig, VllmConfig
20

21
22
# Module-wide logger for the adapter helpers below.
logger = init_logger(__name__)

# Type variable for the model classes wrapped by the `as_*_model` adapters;
# their `(cls: _T) -> _T` signatures hand back a class of the same kind.
_T = TypeVar("_T", bound=type[nn.Module])

# Class-name suffixes of generative models. `_get_pooling_model_name`
# strips these before appending a pooling-specific suffix.
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]
31
32


33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
    """Load Sentence-Transformers Dense projection layers.

    Reads ``modules.json`` from the model repository, keeps every entry of
    type ``sentence_transformers.models.Dense`` and rebuilds each one as a
    float32 ``nn.Linear``, optionally followed by its activation function.

    Args:
        model_config: Model config whose ``model`` and ``revision`` locate
            the repository files.

    Returns:
        An ``nn.Sequential`` of the loaded layers, or ``None`` when the
        repository declares no Dense modules, none of them could be loaded,
        or loading raised (the exception is logged).
    """
    try:
        modules = get_hf_file_to_dict("modules.json", model_config.model,
                                      model_config.revision)
        if not modules:
            return None

        # Accept either a list of module entries or a dict that holds that
        # list under the "modules" key.
        if isinstance(modules, dict):
            modules = modules.get("modules", [])

        dense_modules = [
            m for m in modules
            if m.get("type") == "sentence_transformers.models.Dense"
        ]
        if not dense_modules:
            return None

        layers: list[nn.Module] = []
        for module in dense_modules:
            folder = module.get("path", "")

            config_path = f"{folder}/config.json" if folder else "config.json"
            layer_config = get_hf_file_to_dict(config_path, model_config.model,
                                               model_config.revision)
            if not layer_config:
                continue

            # NOTE(review): 768 is only a fallback when the layer config
            # omits the feature sizes; confirm it matches real checkpoints.
            linear = nn.Linear(layer_config.get("in_features", 768),
                               layer_config.get("out_features", 768),
                               bias=layer_config.get("bias", True),
                               dtype=torch.float32)

            # Skip a layer whose weights could not be loaded rather than
            # keeping its freshly initialised parameters.
            if not _load_dense_weights(linear, folder, model_config):
                continue

            layers.append(linear)
            if act_name := layer_config.get("activation_function"):
                layers.append(get_act_fn(act_name))

        if not layers:
            # An empty nn.Sequential would behave as an identity projection;
            # report "no projector" instead of returning it.
            logger.warning(
                "modules.json declares %d Dense module(s) but none could be "
                "loaded; skipping the Sentence-Transformers projector.",
                len(dense_modules))
            return None

        return nn.Sequential(*layers).to(dtype=torch.float32)

    except Exception:
        # Best effort: a broken projector config must not abort model init.
        logger.exception("ST projector loading failed")

    return None


def _load_dense_weights(linear: nn.Linear, folder: str,
                        model_config: "ModelConfig") -> bool:
    """Load a Dense layer's weights using vLLM's weight_loader pattern.

    Tries ``model.safetensors`` and then ``pytorch_model.bin`` inside
    ``folder`` (or the repository root when ``folder`` is empty). The first
    file that contains one of the recognised weight keys is used.

    Args:
        linear: Layer filled in place. Its weight, and its bias when the
            layer has one, are loaded as float32.
        folder: Repository sub-directory holding the layer files.
        model_config: Model config whose ``model`` and ``revision`` locate
            the files.

    Returns:
        ``True`` once a weight has been loaded, ``False`` if no candidate
        file provided a recognised weight key.
    """
    import io

    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)

    for filename in ["model.safetensors", "pytorch_model.bin"]:
        file_path = f"{folder}/{filename}" if folder else filename

        try:
            file_bytes = get_hf_file_bytes(file_path, model_config.model,
                                           model_config.revision)
            if not file_bytes:
                continue

            if filename.endswith(".safetensors"):
                from safetensors.torch import load as load_safetensors
                state_dict = load_safetensors(file_bytes)
            else:
                # weights_only=True restricts unpickling to tensors and
                # primitive containers.
                state_dict = torch.load(io.BytesIO(file_bytes),
                                        map_location="cpu",
                                        weights_only=True)

            for weight_key in ["weight", "linear.weight", "dense.weight"]:
                if weight_key not in state_dict:
                    continue

                weight_loader = getattr(linear.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(linear.weight,
                              state_dict[weight_key].to(torch.float32))

                bias_key = weight_key.replace("weight", "bias")
                if linear.bias is not None:
                    if bias_key in state_dict:
                        bias_loader = getattr(linear.bias, "weight_loader",
                                              default_weight_loader)
                        bias_loader(linear.bias,
                                    state_dict[bias_key].to(torch.float32))
                    else:
                        # nn.Linear initialises its bias randomly; without
                        # this warning a missing bias would go unnoticed.
                        logger.warning(
                            "Dense layer has a bias but %s has no %r entry; "
                            "keeping the default initialisation.", file_path,
                            bias_key)
                return True
        except Exception:
            logger.exception("Failed to load %s", file_path)
            continue

    return False


125
126
127
128
129
130
131
132
133
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Derive the class name of a pooling adapter.

    Each suffix in ``_GENERATE_SUFFIXES`` is stripped (in list order, at most
    once each) from ``orig_model_name`` and ``pooling_suffix`` is appended,
    e.g. ``"LlamaForCausalLM"`` + ``"ForEmbedding"`` -> ``"LlamaForEmbedding"``.
    """
    stripped = orig_model_name
    for suffix in _GENERATE_SUFFIXES:
        # Guard against an empty suffix: slicing with [:-0] would erase
        # the whole name, whereas str.removesuffix("") is a no-op.
        if suffix and stripped.endswith(suffix):
            stripped = stripped[:-len(suffix)]
    return stripped + pooling_suffix


134
def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """Create a pooling-model subclass of ``orig_cls``.

    The returned class drops the ``lm_head`` and ``logits_processor``
    attributes, delegates pooler construction to ``_init_pooler`` (which the
    concrete adapters override), and skips ``lm_head.*`` checkpoint tensors
    when loading weights.
    """
    # Lazy import
    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForPooling(orig_cls, VllmModelForPooling):

        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            self.vllm_config = vllm_config

            # These are not used in pooling models
            for attr in ("lm_head", "logits_processor"):
                if hasattr(self, attr):
                    delattr(self, attr)

            # If the model already defines a pooler instance, don't overwrite
            # it. Compare with None rather than relying on truthiness: an
            # nn.Module defining __len__ (nn.Sequential, nn.ModuleList, ...)
            # is falsy when empty and would otherwise be replaced.
            if getattr(self, "pooler", None) is None:
                self._init_pooler(vllm_config, prefix=prefix)

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            """Build ``self.pooler``; subclasses must override this."""
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            # TODO: Support uninitialized params tracking

            # We have deleted this attribute, so don't load it
            weights = ((name, data) for name, data in weights
                       if not name.startswith("lm_head."))

            # If `*ForCausalLM` defines `load_weights` on the inner model
            # and there are no other inner modules with parameters,
            # we support loading from both `*Model` and `*ForCausalLM`
            if hasattr(self, "model") and hasattr(self.model, "load_weights"):
                # Whether only `self.model` contains parameters
                model_is_only_param = all(
                    name == "model" or next(child.parameters(), None) is None
                    for name, child in self.named_children())

                if model_is_only_param:
                    # Map "model."-prefixed checkpoint names onto the inner
                    # model, then report loaded names relative to `self`.
                    mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
                    weights = mapper.apply(weights)

                    loaded_params = self.model.load_weights(weights)
                    return {f"model.{name}" for name in loaded_params}

            # For most other models
            if hasattr(orig_cls, "load_weights"):
                return orig_cls.load_weights(self, weights)  # type: ignore

            # Fallback
            loader = AutoWeightsLoader(self)
            return loader.load_weights(weights)

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support embeddings.

    By default, the embeddings of the whole prompt are extracted from the
    normalized hidden state corresponding to the last token.

    The resulting class dispatches the "encode" and "embed" tasks to
    poolers built from the model's ``pooler_config``. Classes that already
    are pooling models are returned unchanged.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Avoid modifying existing embedding models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    class ModelForEmbedding(_create_pooling_model_cls(cls)):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            self.pooler = DispatchPooler({
                "encode": Pooler.for_encode(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            })

    ModelForEmbedding.__name__ = \
        _get_pooling_model_name(cls.__name__, "ForEmbedding")

    return ModelForEmbedding  # type: ignore
234
235


236
def as_seq_cls_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support classify and score tasks.

    By default, the class probabilities are extracted from the softmaxed
    hidden state corresponding to the last token.

    The resulting class adds a float32 ``score`` head
    (``hidden_size -> num_labels``) and dispatches the "encode", "classify"
    and "score" tasks. Classes that already are pooling models are returned
    unchanged.

    Note:
        We assume that the classification head is a single linear layer
        stored as the attribute `score` of the top-level model;
        please implement your own model if this is not the case.
    """
    # Avoid modifying existing classification models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.pooler import (ClassifierPooler,
                                                   DispatchPooler, Pooler,
                                                   PoolingMethod, PoolingType)
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
    from vllm.sequence import IntermediateTensors

    from .utils import maybe_prefix

    class ModelForSequenceClassification(_create_pooling_model_cls(cls),
                                         SupportsCrossEncoding):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            model_config = vllm_config.model_config
            config = model_config.hf_config
            quant_config = vllm_config.quant_config

            self.score = ReplicatedLinear(
                config.hidden_size,
                config.num_labels,
                bias=False,
                params_dtype=torch.float32,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = model_config.pooler_config
            assert pooler_config is not None

            pooling_type_str = pooler_config.pooling_type
            assert pooling_type_str is not None
            pooling_type = PoolingType[pooling_type_str]

            # "classify" and "score" share the `score` head but take their
            # activation functions from different model-config helpers.
            self.pooler = DispatchPooler({
                "encode":
                Pooler.for_encode(pooler_config),
                "classify":
                ClassifierPooler(
                    pooling=PoolingMethod.from_pooling_type(pooling_type),
                    classifier=self._classifier,
                    act_fn=ClassifierPooler.act_fn_for_seq_cls(model_config),
                ),
                "score":
                ClassifierPooler(
                    pooling=PoolingMethod.from_pooling_type(pooling_type),
                    classifier=self._classifier,
                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                        model_config),
                ),
            })

        def _classifier(self, x: torch.Tensor):
            # Upcast to float32 to match the head's params_dtype. `score`
            # returns a pair; only its first element (the output) is used.
            x, _ = self.score(x.float())
            return x

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            return super().forward(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            tokens = getattr(self.config, "classifier_from_token", None)
            method = getattr(self.config, "method", None)

            if tokens is None and method is None:
                return super().load_weights(weights)

            # Online convert ForCausalLM into
            # ForSequenceClassification model.
            return seq_cls_model_loader(self, weights)

    ModelForSequenceClassification.__name__ = \
        _get_pooling_model_name(cls.__name__, "ForSequenceClassification")

    return ModelForSequenceClassification  # type: ignore
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350


def as_reward_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support reward modeling.

    By default, we return the hidden states of each token directly.

    The resulting class only serves the "encode" task, using a pooler built
    from the model's ``pooler_config``. Classes that already are pooling
    models are returned unchanged.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Avoid modifying existing reward models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    class ModelForReward(_create_pooling_model_cls(cls)):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            self.pooler = DispatchPooler(
                {"encode": Pooler.for_encode(pooler_config)})

    ModelForReward.__name__ = \
        _get_pooling_model_name(cls.__name__, "ForReward")

    return ModelForReward  # type: ignore
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387


class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Adjust the HF config of a causal LM served as a sequence classifier.

    Only acts when the HF config sets ``method``. It then derives
    ``num_labels`` from ``classifier_from_token`` and makes sure
    ``use_pad_token`` is set (defaulting to False).
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config
        method = getattr(config, "method", None)
        tokens = getattr(config, "classifier_from_token", None)

        if method is None:
            return

        assert tokens is not None, (
            f"method {method} requires classifier_from_token to be set")
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            # The two token rows are folded into one logit by
            # load_weights_using_from_2_way_softmax, hence a single label.
            assert len(tokens) == 2
            config.num_labels = 1
        else:
            config.num_labels = len(tokens)

        # `llm as reranker` defaults to not using pad_token
        use_pad_token = getattr(config, "use_pad_token", False)
        config.use_pad_token = use_pad_token

392
393
394
395
396
397

def load_weights_using_from_2_way_softmax(
        model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load a causal-LM checkpoint and derive a one-label ``score`` head.

    ``classifier_from_token`` must hold two token strings,
    ``[false_token, true_token]``. The score row is
    ``lm_head[true] - lm_head[false]``, so applying it to a hidden state
    yields ``logit_true - logit_false``. ``sigmoid`` of that difference
    equals the "true" probability of a softmax over the two logits.

    Returns:
        The set returned by ``AutoWeightsLoader.load_weights``, with
        ``score.weight`` added and ``lm_head.weight`` removed.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)
    from vllm.model_executor.models.utils import AutoWeightsLoader
    from vllm.transformers_utils.tokenizer import get_tokenizer

    model_config = model.vllm_config.model_config
    # The entries are token strings, converted to ids by the tokenizer below.
    tokens = getattr(model.config, "classifier_from_token", [])
    tokens = cast(list[str], tokens)
    assert len(tokens) == 2

    # Temporarily attach an LM head so the checkpoint's head weights load.
    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    # NOTE(review): rows are indexed with global token ids; if the lm_head
    # weight is vocab-partitioned across tensor-parallel ranks this indexing
    # may need to account for the local shard — confirm.
    lm_head_weight = model.lm_head.weight.data
    score_weight = (lm_head_weight[[true_id]].to(torch.float32) -
                    lm_head_weight[[false_id]].to(torch.float32))

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The LM head was only needed to build the score head.
    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights


439
440
441
442
443
def load_weights_no_post_processing(model,
                                    weights: Iterable[tuple[str,
                                                            torch.Tensor]]):
    """Load a causal-LM checkpoint and reuse LM-head rows as ``score``.

    Label ``i`` of the ``score`` head is the unmodified LM-head row of the
    token string ``classifier_from_token[i]``, i.e. it reproduces that
    token's logit.

    Returns:
        The set returned by ``AutoWeightsLoader.load_weights``, with
        ``score.weight`` added and ``lm_head.weight`` removed.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)
    from vllm.model_executor.models.utils import AutoWeightsLoader
    from vllm.transformers_utils.tokenizer import get_tokenizer

    model_config = model.vllm_config.model_config
    # The entries are token strings, converted to ids by the tokenizer below.
    tokens = getattr(model.config, "classifier_from_token", [])
    tokens = cast(list[str], tokens)
    assert len(tokens) > 0

    # Temporarily attach an LM head so the checkpoint's head weights load.
    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    # Cast to float32 to match the score head's params_dtype, consistent
    # with load_weights_using_from_2_way_softmax.
    # NOTE(review): indexed with global token ids; confirm this holds if the
    # lm_head weight is vocab-partitioned under tensor parallelism.
    score_weight = model.lm_head.weight.data[token_ids].to(torch.float32)

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The LM head was only needed to build the score head.
    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights


482
483
# Maps the HF config's `method` value to the loader that turns a causal-LM
# checkpoint into a sequence-classification `score` head.
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
    "no_post_processing": load_weights_no_post_processing,
}


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Dispatch to the loader selected by the HF config's ``method``.

    Raises:
        AssertionError: if ``method`` is missing or not a key of
            ``SEQ_CLS_LOAD_METHODS``.
    """
    # Online convert ForCausalLM into ForSequenceClassification model.
    # - from_2_way_softmax:
    #   - Qwen3ForCausalLM
    #     - Qwen3-Reranker
    #   - Qwen2ForCausalLM
    #     - mxbai-rerank-v2
    # - no_post_processing:
    #   - GemmaForCausalLM
    #     - bge-reranker-v2-gemma

    config = model.vllm_config.model_config.hf_config
    method = getattr(config, "method", None)
    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    return SEQ_CLS_LOAD_METHODS[method](model, weights)