# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import TYPE_CHECKING

import vllm.envs as envs
from vllm.config.compilation import CUDAGraphMode
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

if TYPE_CHECKING:
    # Only needed for the quoted "VllmConfig" annotations used below, so it
    # is not imported at runtime.
    from vllm.config import VllmConfig

# Module-level logger used by the config hooks in this file.
logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base class for the per-model config hooks defined in this module.

    Subclasses override :meth:`verify_and_update_config`; the base
    implementation always raises ``NotImplementedError``.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Check and/or adjust ``vllm_config`` in place.

        Args:
            vllm_config: vLLM Config

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class GteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook for HF configs whose class is named ``NewConfig``."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Switch the activation to ``geglu`` and attach ``rotary_kwargs``.

        Args:
            vllm_config: vLLM Config; its ``model_config.hf_config`` is
                mutated in place.

        Raises:
            AssertionError: If the HF config class is not ``NewConfig`` or
                its ``hidden_act`` is not ``"gelu"``.
        """
        hf_cfg = vllm_config.model_config.hf_config

        assert hf_cfg.__class__.__name__ == "NewConfig"
        assert hf_cfg.hidden_act == "gelu"

        # The checkpoint declares "gelu"; the hook rewrites it to "geglu".
        hf_cfg.hidden_act = "geglu"

        per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        # rotary_dim falls back to the full head size when the config has
        # no explicit rotary_emb_dim.
        rotary_dim = getattr(hf_cfg, "rotary_emb_dim", per_head)
        max_position = hf_cfg.max_position_embeddings
        base = hf_cfg.rope_theta
        rope_scaling = getattr(hf_cfg, "rope_scaling", None)

        hf_cfg.rotary_kwargs = {
            "head_size": per_head,
            "rotary_dim": rotary_dim,
            "max_position": max_position,
            "base": base,
            "rope_scaling": rope_scaling,
        }


48
49
50
51
52
53
54
55
56
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook that defaults the pooler ``activation`` to ``False``."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``pooler_config.activation`` to ``False`` when it is unset.

        Args:
            vllm_config: vLLM Config; its ``model_config.pooler_config`` is
                mutated in place.
        """
        pooler = vllm_config.model_config.pooler_config
        # An explicit value (True or False) is left untouched.
        if pooler.activation is not None:
            return
        pooler.activation = False


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    """Config hook that attaches ``rotary_kwargs`` to rotary configs."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Build ``rotary_kwargs`` when rotary position embeddings are used.

        Configs whose ``position_embedding_type`` is not ``"rotary"`` are
        left unchanged.

        Args:
            vllm_config: vLLM Config; its ``model_config.hf_config`` may be
                mutated in place.

        Raises:
            AssertionError: If a rotary config is not an
                ``XLMRobertaFlashConfig``.
        """
        hf_cfg = vllm_config.model_config.hf_config

        # Nothing to do for non-rotary position embeddings.
        if hf_cfg.position_embedding_type != "rotary":
            return

        assert hf_cfg.__class__.__name__ == "XLMRobertaFlashConfig"

        per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        hf_cfg.rotary_kwargs = dict(
            head_size=per_head,
            rotary_dim=getattr(hf_cfg, "rotary_emb_dim", per_head),
            max_position=hf_cfg.max_position_embeddings,
            # rotary_emb_base is evaluated eagerly as the fallback, so it
            # must exist on the config even when rope_theta is present.
            base=getattr(hf_cfg, "rope_theta", hf_cfg.rotary_emb_base),
            rope_scaling=getattr(hf_cfg, "rope_scaling", None),
        )


class NomicBertModelConfig(VerifyAndUpdateConfig):
    """Config hook for HF configs whose class is named ``NomicBertConfig``.

    Copies Nomic-specific attribute names onto the names the rest of this
    method reads (``hidden_size``, ``hidden_act``, ...), builds
    ``rotary_kwargs`` and recalculates ``max_model_len``.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate the Nomic BERT config and update ``vllm_config`` in place.

        Args:
            vllm_config: vLLM Config; its ``model_config`` (including the
                HF config, HF text config and encoder config) is mutated.

        Raises:
            AssertionError: If the HF config is not a ``NomicBertConfig``,
                uses an activation other than ``swiglu``/``gelu``, has
                mismatching bias flags, sets ``rotary_emb_scale_base`` or
                enables ``rotary_emb_interleaved``.
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        # Default to "rope" when the config does not say otherwise.
        config.position_embedding_type = getattr(config,
                                                 "position_embedding_type",
                                                 "rope")

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        # A single `bias` flag is exposed, so all three must agree.
        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
                config.qkv_proj_bias)
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        # Copy the Nomic attribute names onto the generic ones.
        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
        max_trained_positions = getattr(config, "max_trained_positions", 2048)
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_emb_dim,
            "max_position": max_trained_positions,
            "base": getattr(config, "rope_theta", config.rotary_emb_base),
            "rope_scaling": getattr(config, "rope_scaling", None)
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_scaling.
        # See #17785 #18755
        if (not vllm_config.model_config.hf_overrides
                and vllm_config.model_config.original_max_model_len is None):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = vllm_config.model_config.max_model_len
            max_model_len = min(vllm_config.model_config.max_model_len,
                                max_trained_positions)

            vllm_config.recalculate_max_model_len(max_model_len)
            logger.warning(
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
                max_model_len_before, vllm_config.model_config.max_model_len)
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            model_config = vllm_config.model_config
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", vllm_config.model_config.max_model_len)
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = vllm_config.model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]

            # The priority of sentence_bert_config.json is higher
            # than max_position_embeddings
            # Work on a copy so the original encoder config object is not
            # mutated.
            encoder_config = deepcopy(model_config.encoder_config)
            if encoder_config is not None:
                # With no encoder config there is no max_seq_length to
                # drop; previously this path raised AttributeError.
                encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

            vllm_config.recalculate_max_model_len(max_model_len)


168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    """Config hook that supplies a default pooler ``step_tag_id``."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``pooler_config.step_tag_id`` to 151651 when it is unset.

        Args:
            vllm_config: vLLM Config; its ``model_config.pooler_config`` is
                mutated in place.
        """
        pooler = vllm_config.model_config.pooler_config

        # A value already present on the pooler config is kept as-is.
        if pooler.step_tag_id is not None:
            return
        pooler.step_tag_id = 151651


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    """Config hook that defaults the pooler ``softmax`` flag to ``False``."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``pooler_config.softmax`` to ``False`` when it is unset.

        Args:
            vllm_config: vLLM Config; its ``model_config.pooler_config`` is
                mutated in place.
        """
        pooler = vllm_config.model_config.pooler_config

        # Only fill in the default; an explicit True/False is respected.
        if pooler.softmax is not None:
            return
        pooler.softmax = False


188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for configs flagged ``is_original_qwen3_reranker``."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set the HF config ``method`` to ``"from_2_way_softmax"``.

        Does nothing unless the HF config sets
        ``is_original_qwen3_reranker`` to a truthy value.

        Args:
            vllm_config: vLLM Config; its ``model_config.hf_config`` may be
                mutated in place.

        Raises:
            AssertionError: If ``classifier_from_token`` is missing or does
                not contain exactly two entries.
        """
        config = vllm_config.model_config.hf_config

        is_original_qwen3_reranker = getattr(config,
                                             "is_original_qwen3_reranker",
                                             False)

        if not is_original_qwen3_reranker:
            return

        # The check below requires exactly two classifier tokens.
        tokens = getattr(config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, \
            ("Try loading the original Qwen3 Reranker?, see: "
             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
        vllm_config.model_config.hf_config.method = "from_2_way_softmax"
206
207


208
209
210
211
212
213
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook that forces one label and a default pooler logit bias."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``num_labels`` to 1 and default ``logit_bias`` to 2.65.

        Args:
            vllm_config: vLLM Config; its ``model_config.hf_config`` and
                ``model_config.pooler_config`` are mutated in place.
        """
        config = vllm_config.model_config.hf_config
        # Unconditional: any num_labels already on the config is replaced.
        config.num_labels = 1
        pooler_config = vllm_config.model_config.pooler_config
        # Only applied when no logit_bias has been set already.
        if pooler_config.logit_bias is None:
            pooler_config.logit_bias = 2.65
217
218


219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook for HF configs whose class is named ``GteConfig``."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Switch the activation to ``geglu`` and attach ``rotary_kwargs``.

        Args:
            vllm_config: vLLM Config; its ``model_config.hf_config`` is
                mutated in place.

        Raises:
            AssertionError: If the HF config class is not ``GteConfig`` or
                its ``hidden_act`` is not ``"gelu"``.
        """
        hf_cfg = vllm_config.model_config.hf_config

        assert hf_cfg.__class__.__name__ == "GteConfig"
        assert hf_cfg.hidden_act == "gelu"

        # The checkpoint declares "gelu"; the hook rewrites it to "geglu".
        hf_cfg.hidden_act = "geglu"

        per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        hf_cfg.rotary_kwargs = dict(
            head_size=per_head,
            # Fall back to the full head size when rotary_emb_dim is absent.
            rotary_dim=getattr(hf_cfg, "rotary_emb_dim", per_head),
            max_position=hf_cfg.max_position_embeddings,
            base=hf_cfg.rope_theta,
            rope_scaling=getattr(hf_cfg, "rope_scaling", None),
        )


240
241
242
243
244
245
246
247
248
249
250
251
252
class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
    """Config hook that raises ``max_seq_len_to_capture`` to the max length."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``max_seq_len_to_capture`` equal to ``max_model_len``.

        Args:
            vllm_config: vLLM Config; its ``model_config`` is mutated in
                place.
        """
        model_cfg = vllm_config.model_config
        capture_len = model_cfg.max_model_len
        model_cfg.max_seq_len_to_capture = capture_len

        logger.info(
            "Setting max_seq_len_to_capture to %d "
            "to ensure that CUDA graph capture "
            "covers sequences of length up to max_model_len.", capture_len)


253
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    """Config hook that sets a default reasoning backend and graph sizes."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Default the reasoning backend and widen CUDA graph capture sizes.

        Args:
            vllm_config: vLLM Config; its ``decoding_config`` and
                ``scheduler_config`` may be mutated in place.
        """
        decoding_config = vllm_config.decoding_config
        # Only an empty string is treated as "not set".
        if decoding_config.reasoning_backend == "":
            decoding_config.reasoning_backend = "openai_gptoss"

        # Increase the max capture size from 512 to 1024 for performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 83.
        scheduler_config = vllm_config.scheduler_config
        # A multi-entry cuda_graph_sizes list is left untouched.
        if len(scheduler_config.cuda_graph_sizes) == 1:
            max_capture_size = scheduler_config.cuda_graph_sizes[0]
            # FIXME(woosuk): When using full cuda graph with FA3, the max
            # supported size is 992.
            if max_capture_size < 1024:
                cuda_graph_sizes = [1, 2, 4]
                # Step size 8 for small batch sizes
                cuda_graph_sizes += list(range(8, 256, 8))
                # Step size 16 for larger batch sizes
                cuda_graph_sizes += list(range(256, 1025, 16))
                scheduler_config.cuda_graph_sizes = cuda_graph_sizes
                logger.info(
                    "Overriding max cuda graph capture size to "
                    "%d for performance.", 1024)


281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
class MambaModelConfig(VerifyAndUpdateConfig):
    """Config hook for the mamba-based architectures mapped to it."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Enable FULL_AND_PIECEWISE cuda graph mode by default (required
        to get good performance for mamba layers in V1) and disable
        prefix caching. Does nothing when V1 is not in use.

        Args:
            vllm_config: vLLM Config
        """

        if not envs.VLLM_USE_V1:
            return

        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config

        # TODO(tdoublep): remove once prefix caching is enabled
        cache_config.enable_prefix_caching = False
        logger.info("Hybrid or mamba-based model detected: disabling prefix "
                    "caching since it is not yet supported.")

        # TODO(tdoublep): remove as full cuda graph support is added
        FCG_NOT_SUPPORTED_MODELS = [
            "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
        ]

        # Only fill in a default: a cudagraph_mode that is already set is
        # never overridden.
        if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS
                and compilation_config.cudagraph_mode is None):
            logger.info(
                "Hybrid or mamba-based model detected: setting cudagraph mode "
                "to FULL_AND_PIECEWISE in order to optimize performance.")
            compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE


318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    """Config hook that aligns attention and mamba page sizes."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure that page size of attention layers is greater than or
        equal to the mamba layers. If not, automatically set the attention
        block size to ensure that it is. If the attention page size is
        strictly greater than the mamba page size, we pad the mamba page size
        to make them equal.

        Does nothing when V1 is not in use. Otherwise the
        ``MambaModelConfig`` hook is applied first.

        Args:
            vllm_config: vLLM Config

        Raises:
            AssertionError: If the resulting attention page size is still
                smaller than the mamba page size.
        """

        if not envs.VLLM_USE_V1:
            return

        # Enable FULL_AND_PIECEWISE by default
        MambaModelConfig.verify_and_update_config(vllm_config)

        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        # "auto" means the KV cache uses the model dtype.
        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # get attention page size (for 1 token)
        attn_page_size_1_token = FullAttentionSpec(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype,
            use_mla=model_config.use_mla).page_size_bytes

        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )

        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=model_config.max_model_len,
        ).page_size_bytes

        # some attention backends (e.g. FA) only support setting
        # block size to multiple of 16, so let's suggest a value
        # that would work (note: FA is currently not compatible
        # with mamba layers, use FlashInfer instead).
        # i.e. the smallest multiple of 16 tokens whose attention page
        # covers the mamba page.
        attn_block_size = 16 * cdiv(mamba_page_size,
                                    16 * attn_page_size_1_token)

        # override attention block size if either (a) the
        # user has not set it or (b) the user has set it
        # too small.
        if (cache_config.block_size is None
                or cache_config.block_size < attn_block_size):
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size)

        # compute new attention page size
        attn_page_size = \
            cache_config.block_size * attn_page_size_1_token

        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # don't need to pad mamba page size
            return

        # pad mamba page size to exactly match attention
        if (cache_config.mamba_page_size_padded is None
                or cache_config.mamba_page_size_padded != attn_page_size):
            cache_config.mamba_page_size_padded = (attn_page_size)
            # Padding overhead relative to the unpadded mamba page size.
            mamba_padding_pct = 100 * (attn_page_size -
                                       mamba_page_size) / mamba_page_size
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.", mamba_padding_pct)


408
409
410
# Lookup table from a model name string (these look like architecture
# names -- confirm against the caller) to the config hook class for it.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "FalconMambaForCausalLM": MambaModelConfig,
}