adapters.py 14.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
6
7
8
9

import torch
import torch.nn as nn

10
11
from vllm.model_executor.models.config import VerifyAndUpdateConfig

12
from .interfaces_base import VllmModelForPooling, is_pooling_model
13

14
if TYPE_CHECKING:
15
    from vllm.config import VllmConfig
16

17
18
_T = TypeVar("_T", bound=type[nn.Module])

# Suffixes of generative model class names. They are stripped, in this
# order, when deriving the class name of a pooling adapter.
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]


def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Build the class name of a pooling adapter for ``orig_model_name``.

    Every entry of ``_GENERATE_SUFFIXES`` is checked in turn and removed when
    the (progressively shortened) name ends with it, so several suffixes can
    be stripped from one name. ``pooling_suffix`` is then appended.

    Example: ``("LlamaForCausalLM", "ForEmbedding")`` -> ``"LlamaForEmbedding"``.
    """
    stem = orig_model_name
    for suffix in _GENERATE_SUFFIXES:
        if stem.endswith(suffix):
            stem = stem[:len(stem) - len(suffix)]
    return stem + pooling_suffix


36
def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """Return a pooling-model subclass of ``orig_cls``.

    The generated class mixes in ``VllmModelForPooling`` and sets
    ``is_pooling_model = True``. After construction it deletes the
    ``lm_head`` and ``logits_processor`` attributes. It then calls
    ``_init_pooler`` unless the wrapped model already set a truthy
    ``pooler``. ``_init_pooler`` raises ``NotImplementedError`` here; the
    ``as_*_model`` adapters in this module override it.

    Args:
        orig_cls: The vLLM model class to wrap.

    Returns:
        The new subclass (annotated as ``_T`` so callers keep the input type).
    """
    # Lazy import
    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForPooling(orig_cls, VllmModelForPooling):

        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            self.vllm_config = vllm_config

            # Generation-only components have no role in a pooling model.
            for unused in ("lm_head", "logits_processor"):
                if hasattr(self, unused):
                    delattr(self, unused)

            # Respect a pooler the wrapped model built itself; only create
            # one when it is missing (or falsy).
            existing_pooler = getattr(self, "pooler", None)
            if not existing_pooler:
                self._init_pooler(vllm_config, prefix=prefix)

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            """Create ``self.pooler``; concrete adapters must override this."""
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            """Load checkpoint tensors, skipping the deleted ``lm_head``.

            Strategies, tried in order:

            1. If ``self.model`` defines ``load_weights`` and no other child
               module owns parameters, strip the ``model.`` prefix, delegate
               to ``self.model.load_weights`` and re-prefix the returned
               names. This accepts both ``*Model`` and ``*ForCausalLM``
               checkpoint layouts.
            2. Otherwise use ``orig_cls.load_weights`` when it exists.
            3. Otherwise fall back to ``AutoWeightsLoader``.
            """
            # TODO: Support uninitialized params tracking

            # `lm_head` was deleted in __init__, so drop its tensors.
            kept = ((name, tensor) for name, tensor in weights
                    if not name.startswith("lm_head."))

            inner_model = getattr(self, "model", None)
            if hasattr(inner_model, "load_weights"):
                another_child_has_params = any(
                    child_name != "model"
                    and next(child.parameters(), None) is not None
                    for child_name, child in self.named_children())

                if not another_child_has_params:
                    prefix_stripper = WeightsMapper(
                        orig_to_new_prefix={"model.": ""})
                    inner_loaded = inner_model.load_weights(
                        prefix_stripper.apply(kept))
                    return {f"model.{name}" for name in inner_loaded}

            # For most other models
            if hasattr(orig_cls, "load_weights"):
                return orig_cls.load_weights(self, kept)  # type: ignore

            # Fallback
            return AutoWeightsLoader(self).load_weights(kept)

    return ModelForPooling  # type: ignore


def as_embedding_model(cls: _T) -> _T:
    """
    Wrap an existing vLLM model class so it can serve embedding requests.

    The adapter's pooler dispatches the ``encode`` and ``embed`` tasks to
    ``Pooler.for_encode`` / ``Pooler.for_embed`` built from the model's
    pooler config. By default, the embeddings of the whole prompt are
    extracted from the normalized hidden state corresponding to the last
    token.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Classes that already pool are returned unchanged.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForEmbedding(pooling_base):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            poolers_by_task = {
                "encode": Pooler.for_encode(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
            self.pooler = DispatchPooler(poolers_by_task)

    ModelForEmbedding.__name__ = _get_pooling_model_name(
        cls.__name__, "ForEmbedding")

    return ModelForEmbedding  # type: ignore
136
137


138
def as_seq_cls_model(cls: _T) -> _T:
    """
    Wrap an existing vLLM model class so it can serve classify and score tasks.

    A float32 linear head, stored as the top-level attribute ``score``, maps
    pooled hidden states to ``config.num_labels`` outputs. Hidden states are
    pooled with the pooler config's ``pooling_type``, defaulting to
    ``PoolingType.LAST`` (the last token). The activation applied is chosen
    per task by ``ClassifierPooler.act_fn_for_seq_cls`` (classify) and
    ``ClassifierPooler.act_fn_for_cross_encoder`` (score).

    If the HF config defines ``classifier_from_token`` or ``method``, a
    ``*ForCausalLM`` checkpoint is converted into the ``score`` head while
    loading (see ``seq_cls_model_loader``).

    Note:
        We assume that the classification head is a single linear layer
        stored as the attribute `score` of the top-level model;
        please implement your own model if this is not the case.
    """
    # Classes that already pool are returned unchanged.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import RowParallelLinear
    from vllm.model_executor.layers.pooler import (ClassifierPooler,
                                                   DispatchPooler, Pooler,
                                                   PoolingMethod, PoolingType)
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
    from vllm.sequence import IntermediateTensors

    from .utils import maybe_prefix

    class ModelForSequenceClassification(_create_pooling_model_cls(cls),
                                         SupportsCrossEncoding):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            model_config = vllm_config.model_config
            hf_config = model_config.hf_config

            # Classification head, kept in float32 (see `_classifier`).
            self.score = RowParallelLinear(
                hf_config.hidden_size,
                hf_config.num_labels,
                input_is_parallel=False,
                bias=False,
                params_dtype=torch.float32,
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "score"),
            )

            pooler_config = model_config.pooler_config
            assert pooler_config is not None

            requested_type = pooler_config.pooling_type
            if requested_type is None:
                pooling_type = PoolingType.LAST
            else:
                pooling_type = PoolingType[requested_type]

            def classifier_pooler(act_fn_factory):
                # Each task gets its own pooling method instance.
                return ClassifierPooler(
                    pooling=PoolingMethod.from_pooling_type(pooling_type),
                    classifier=self._classifier,
                    act_fn=act_fn_factory(vllm_config.model_config),
                )

            self.pooler = DispatchPooler({
                "encode":
                Pooler.for_encode(pooler_config),
                "classify":
                classifier_pooler(ClassifierPooler.act_fn_for_seq_cls),
                "score":
                classifier_pooler(ClassifierPooler.act_fn_for_cross_encoder),
            })

        def _classifier(self, x: torch.Tensor):
            # Apply the head in float32; only the first element of the
            # layer's returned pair is used.
            logits, _ = self.score(x.float())
            return logits

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            return super().forward(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            hf_config = self.config
            tokens = getattr(hf_config, "classifier_from_token", None)
            method = getattr(hf_config, "method", None)

            # Checkpoints that already contain a classification head.
            if tokens is None and method is None:
                return super().load_weights(weights)

            # Online convert ForCausalLM into
            # ForSequenceClassification model.
            return seq_cls_model_loader(self, weights)

    ModelForSequenceClassification.__name__ = _get_pooling_model_name(
        cls.__name__, "ForSequenceClassification")

    return ModelForSequenceClassification  # type: ignore
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253


def as_reward_model(cls: _T) -> _T:
    """
    Wrap an existing vLLM model class so it can serve reward modeling.

    The adapter's pooler handles only the ``encode`` task, using
    ``Pooler.for_encode`` built from the model's pooler config. By default,
    we return the hidden states of each token directly.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Classes that already pool are returned unchanged.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForReward(pooling_base):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            encode_pooler = Pooler.for_encode(pooler_config)
            self.pooler = DispatchPooler({"encode": encode_pooler})

    ModelForReward.__name__ = _get_pooling_model_name(
        cls.__name__, "ForReward")

    return ModelForReward  # type: ignore
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290


class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for causal LMs converted online into classifiers."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate and adjust the HF config for online classifier conversion.

        This is a no-op unless the HF config defines ``method``. When it
        does:

        - ``classifier_from_token`` must be set.
        - ``method`` must be a key of ``SEQ_CLS_LOAD_METHODS``.
        - ``num_labels`` becomes 1 for ``from_2_way_softmax``, which
          requires exactly two tokens, and the number of tokens otherwise.
        - ``use_pad_token`` defaults to ``False`` when absent.
        """
        hf_config = vllm_config.model_config.hf_config
        method = getattr(hf_config, "method", None)
        tokens = getattr(hf_config, "classifier_from_token", None)

        if method is None:
            return

        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        is_two_way = method == "from_2_way_softmax"
        if is_two_way:
            assert len(tokens) == 2
        hf_config.num_labels = 1 if is_two_way else len(tokens)

        # `llm as reranker` defaults to not using pad_token
        hf_config.use_pad_token = getattr(hf_config, "use_pad_token", False)

295
296
297
298
299
300

def load_weights_using_from_2_way_softmax(
        model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Convert a causal-LM checkpoint into a single-output ``score`` head.

    ``config.classifier_from_token`` must hold exactly two token *strings*,
    ``[false_token, true_token]``. After the full checkpoint (including the
    LM head) is loaded, ``score.weight`` is set to
    ``lm_head[true_id] - lm_head[false_id]``, computed in float32.

    A temporary ``lm_head`` is attached for loading and deleted once
    ``score`` is populated. It is the input embedding when
    ``tie_word_embeddings`` is set, and a fresh ``ParallelLMHead`` otherwise.

    Args:
        model: The sequence-classification adapter being loaded. It reads
            ``config``, ``vllm_config``, ``model.embed_tokens``,
            ``quant_config`` and ``score`` from it.
        weights: ``(name, tensor)`` pairs from the causal-LM checkpoint.

    Returns:
        Names of the loaded parameters, with ``score.weight`` added and
        ``lm_head.weight`` removed.

    Raises:
        ValueError: If the tokenizer cannot resolve a classifier token.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)
    from vllm.model_executor.models.utils import AutoWeightsLoader
    from vllm.transformers_utils.tokenizer import get_tokenizer

    model_config = model.vllm_config.model_config
    # These are token strings; they are resolved to ids by the tokenizer.
    tokens = cast(list[str],
                  getattr(model.config, "classifier_from_token", []))
    assert len(tokens) == 2

    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])
    # NOTE(review): HF tokenizers without an unk token return None for
    # out-of-vocabulary tokens; fail clearly instead of indexing with None.
    missing = [t for t, i in zip(tokens, (false_id, true_id)) if i is None]
    if missing:
        raise ValueError(f"classifier_from_token entries {missing!r} are "
                         "not in the tokenizer vocabulary")

    lm_head_weight = model.lm_head.weight.data
    score_weight = (lm_head_weight[[true_id]].to(torch.float32) -
                    lm_head_weight[[false_id]].to(torch.float32))

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The temporary head is no longer needed once `score` is populated.
    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights


342
343
344
345
346
def load_weights_no_post_processing(model,
                                    weights: Iterable[tuple[str,
                                                            torch.Tensor]]):
    """Convert a causal-LM checkpoint into an N-output ``score`` head.

    ``config.classifier_from_token`` holds one or more token *strings*.
    After the full checkpoint (including the LM head) is loaded, the ``score``
    weight is filled with the ``lm_head`` rows of those tokens, in order,
    with no further arithmetic.

    A temporary ``lm_head`` is attached for loading and deleted once
    ``score`` is populated. It is the input embedding when
    ``tie_word_embeddings`` is set, and a fresh ``ParallelLMHead`` otherwise.

    Args:
        model: The sequence-classification adapter being loaded. It reads
            ``config``, ``vllm_config``, ``model.embed_tokens``,
            ``quant_config`` and ``score`` from it.
        weights: ``(name, tensor)`` pairs from the causal-LM checkpoint.

    Returns:
        Names of the loaded parameters, with ``score.weight`` added and
        ``lm_head.weight`` removed.

    Raises:
        ValueError: If the tokenizer cannot resolve a classifier token.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)
    from vllm.model_executor.models.utils import AutoWeightsLoader
    from vllm.transformers_utils.tokenizer import get_tokenizer

    model_config = model.vllm_config.model_config
    # These are token strings; they are resolved to ids by the tokenizer.
    tokens = cast(list[str],
                  getattr(model.config, "classifier_from_token", []))
    assert len(tokens) > 0

    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    # NOTE(review): HF tokenizers without an unk token return None for
    # out-of-vocabulary tokens; fail clearly instead of indexing with None.
    missing = [t for t, i in zip(tokens, token_ids) if i is None]
    if missing:
        raise ValueError(f"classifier_from_token entries {missing!r} are "
                         "not in the tokenizer vocabulary")

    score_weight = model.lm_head.weight.data[token_ids]

    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    # The temporary head is no longer needed once `score` is populated.
    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights


385
386
# Maps the HF config's `method` value to the loader that converts a
# causal-LM checkpoint into a classification head; consulted by
# `seq_cls_model_loader` and validated in `SequenceClassificationConfig`.
SEQ_CLS_LOAD_METHODS = dict(
    from_2_way_softmax=load_weights_using_from_2_way_softmax,
    no_post_processing=load_weights_no_post_processing,
)


def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Online convert a ForCausalLM checkpoint into a classification model.

    The conversion is chosen by the HF config attribute ``method``, which
    must be a key of ``SEQ_CLS_LOAD_METHODS``. The selected loader's return
    value is returned unchanged.

    Known users of each method:

    - ``from_2_way_softmax``: Qwen3ForCausalLM (Qwen3-Reranker) and
      Qwen2ForCausalLM (mxbai-rerank-v2).
    - ``no_post_processing``: GemmaForCausalLM (bge-reranker-v2-gemma).
    """
    hf_config = model.vllm_config.model_config.hf_config
    method = getattr(hf_config, "method", None)
    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
    convert = SEQ_CLS_LOAD_METHODS[method]
    return convert(model, weights)