# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import TYPE_CHECKING

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base class for per-architecture config adapters.

    Subclasses override ``verify_and_update_config`` to check and patch a
    ``VllmConfig`` in place for one model architecture.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate and adjust ``vllm_config``; subclasses must override."""
        raise NotImplementedError()


class Gemma3TextModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Derive ``hf_config.is_causal`` from the bidirectional flag.

        ``is_causal`` is set to the negation of
        ``hf_config.use_bidirectional_attention``. Configs that do not define
        that flag are treated as causal.
        """
        hf_config = vllm_config.model_config.hf_config
        # Some Gemma3 text configs may not carry the flag at all; default to
        # causal attention instead of raising AttributeError.
        use_bidirectional = getattr(hf_config, "use_bidirectional_attention",
                                    False)
        hf_config.is_causal = not use_bidirectional


class GteNewModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Adapt a GTE ``NewConfig`` checkpoint.

        Requires ``hidden_act == "gelu"`` and rewrites it to ``"geglu"``,
        then stores the rotary-embedding keyword arguments on the config as
        ``rotary_kwargs``.
        """
        hf_cfg = vllm_config.model_config.hf_config

        assert hf_cfg.__class__.__name__ == "NewConfig"
        assert hf_cfg.hidden_act == "gelu"
        hf_cfg.hidden_act = "geglu"

        per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        rotary_dim = getattr(hf_cfg, "rotary_emb_dim", per_head)
        hf_cfg.rotary_kwargs = dict(
            head_size=per_head,
            rotary_dim=rotary_dim,
            max_position=hf_cfg.max_position_embeddings,
            base=hf_cfg.rope_theta,
            rope_scaling=getattr(hf_cfg, "rope_scaling", None),
        )


class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Default ``pooler_config.activation`` to ``False`` when unset."""
        pooler = vllm_config.model_config.pooler_config
        if pooler.activation is not None:
            return
        pooler.activation = False


class JinaRobertaModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Attach rotary-embedding kwargs to rotary XLM-RoBERTa configs.

        Only configs with ``position_embedding_type == "rotary"`` are
        touched; those must be ``XLMRobertaFlashConfig`` instances. The
        result is stored on the config as ``rotary_kwargs``.
        """
        config = vllm_config.model_config.hf_config

        if config.position_embedding_type != "rotary":
            return

        assert config.__class__.__name__ == "XLMRobertaFlashConfig"

        # Prefer ``rope_theta``; read ``rotary_emb_base`` only as a fallback
        # so configs that define ``rope_theta`` alone do not raise
        # AttributeError (a getattr default is evaluated eagerly).
        if hasattr(config, "rope_theta"):
            rope_base = config.rope_theta
        else:
            rope_base = config.rotary_emb_base

        head_dim = config.hidden_size // config.num_attention_heads
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": rope_base,
            "rope_scaling": getattr(config, "rope_scaling", None)
        }


class NomicBertModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Normalize a ``NomicBertConfig`` and reconcile its context length.

        Copies Nomic-specific attribute names onto the conventional ones
        (``hidden_size``, ``num_hidden_layers``, ...), derives ``hidden_act``
        and ``bias``, stores rotary-embedding kwargs on the config as
        ``rotary_kwargs`` and then adjusts ``max_model_len``.

        Raises:
            AssertionError: if the config is not a ``NomicBertConfig`` or
                uses an option combination this adapter does not support.
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(config,
                                                 "position_embedding_type",
                                                 "rope")

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        # Only a single ``bias`` flag is propagated, so the qkv and both MLP
        # bias settings must agree.
        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
                config.qkv_proj_bias)
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        # Copy Nomic attribute names onto the conventional names.
        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
        max_trained_positions = getattr(config, "max_trained_positions", 2048)

        # Prefer ``rope_theta``; read ``rotary_emb_base`` only as a fallback
        # so configs that define ``rope_theta`` alone do not raise
        # AttributeError (a getattr default is evaluated eagerly).
        if hasattr(config, "rope_theta"):
            rope_base = config.rope_theta
        else:
            rope_base = config.rotary_emb_base

        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_emb_dim,
            "max_position": max_trained_positions,
            "base": rope_base,
            "rope_scaling": getattr(config, "rope_scaling", None)
        }

        NomicBertModelConfig._update_max_model_len(vllm_config, config,
                                                   max_trained_positions)

    @staticmethod
    def _update_max_model_len(vllm_config: "VllmConfig", config,
                              max_trained_positions: int) -> None:
        """Clamp or re-derive ``max_model_len`` for a Nomic model.

        Without ``hf_overrides`` and without ``original_max_model_len`` the
        length is clamped to ``max_trained_positions``. Otherwise the text
        config is reset so ``recalculate_max_model_len`` re-verifies the
        requested length against the trained position embeddings.
        """
        model_config = vllm_config.model_config

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_scaling.
        # See #17785 #18755
        if (not model_config.hf_overrides
                and model_config.original_max_model_len is None):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len,
                                max_trained_positions)

            vllm_config.recalculate_max_model_len(max_model_len)
            logger.warning(
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                max_model_len_before, vllm_config.model_config.max_model_len)
            return

        # We need to re-verify max_model_len to avoid lengths
        # greater than position_embedding.
        hf_text_config = model_config.hf_text_config

        if isinstance(model_config.hf_overrides, dict):
            # hf_overrides_kw
            max_model_len = model_config.hf_overrides.get(
                "max_model_len", model_config.max_model_len)
        else:
            # hf_overrides_fn
            # This might be overridden by sentence_bert_config.json.
            max_model_len = model_config.max_model_len

        # reset hf_text_config for recalculate_max_model_len.
        if hasattr(hf_text_config, "max_model_len"):
            delattr(hf_text_config, "max_model_len")
        hf_text_config.max_position_embeddings = max_trained_positions
        hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]

        # The priority of sentence_bert_config.json is higher
        # than max_position_embeddings, so drop its max_seq_length.
        # encoder_config may be None (no sentence_bert_config.json); there is
        # nothing to drop in that case.
        if model_config.encoder_config is not None:
            encoder_config = deepcopy(model_config.encoder_config)
            encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

        vllm_config.recalculate_max_model_len(max_model_len)


class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Default ``pooler_config.step_tag_id`` to 151651 when unset."""
        pooler = vllm_config.model_config.pooler_config
        if pooler.step_tag_id is not None:
            return
        # NOTE(review): presumably the token id of the step-separator token
        # in the Qwen2 tokenizer — confirm against the tokenizer vocab.
        pooler.step_tag_id = 151651


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Default ``pooler_config.softmax`` to ``False`` when unset."""
        pooler = vllm_config.model_config.pooler_config
        if pooler.softmax is not None:
            return
        pooler.softmax = False


class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Prepare an original Qwen3 reranker checkpoint.

        Only acts when ``hf_config.is_original_qwen3_reranker`` is truthy;
        it then requires exactly two ``classifier_from_token`` entries and
        sets ``hf_config.method`` to ``"from_2_way_softmax"``.
        """
        hf_config = vllm_config.model_config.hf_config

        if not getattr(hf_config, "is_original_qwen3_reranker", False):
            return

        classifier_tokens = getattr(hf_config, "classifier_from_token", None)
        assert (classifier_tokens is not None
                and len(classifier_tokens) == 2), (
                    "Try loading the original Qwen3 Reranker?, see: "
                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")

        hf_config.method = "from_2_way_softmax"


class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Force a single output label and default the pooler logit bias.

        ``hf_config.num_labels`` is always set to 1; ``logit_bias`` is set to
        2.65 only when the pooler config leaves it unset.
        """
        model_config = vllm_config.model_config
        model_config.hf_config.num_labels = 1

        pooler = model_config.pooler_config
        if pooler.logit_bias is not None:
            return
        pooler.logit_bias = 2.65


class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Adapt a Snowflake GTE (``GteConfig``) checkpoint.

        Requires ``hidden_act == "gelu"`` and rewrites it to ``"geglu"``,
        then stores the rotary-embedding keyword arguments on the config as
        ``rotary_kwargs``.
        """
        cfg = vllm_config.model_config.hf_config

        assert cfg.__class__.__name__ == "GteConfig"
        assert cfg.hidden_act == "gelu"
        cfg.hidden_act = "geglu"

        size_per_head = cfg.hidden_size // cfg.num_attention_heads
        kwargs = {"head_size": size_per_head}
        kwargs["rotary_dim"] = getattr(cfg, "rotary_emb_dim", size_per_head)
        kwargs["max_position"] = cfg.max_position_embeddings
        kwargs["base"] = cfg.rope_theta
        kwargs["rope_scaling"] = getattr(cfg, "rope_scaling", None)
        cfg.rotary_kwargs = kwargs


class GptOssForCausalLMConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Apply GPT-OSS defaults.

        - Use the ``openai_gptoss`` reasoning parser when none is set.
        - When exactly one CUDA graph size is configured and it is below
          992, replace it with an explicit list of capture sizes up to 992.
        """
        structured_outputs_config = vllm_config.structured_outputs_config
        if structured_outputs_config.reasoning_parser == "":
            structured_outputs_config.reasoning_parser = "openai_gptoss"

        # Increase the max capture size from 512 to 992 for performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 81.
        # FIXME(woosuk): When using full cuda graph with FA3, the max
        # supported size is 992.
        max_capture_size_limit = 992

        scheduler_config = vllm_config.scheduler_config
        if len(scheduler_config.cuda_graph_sizes) != 1:
            return
        if scheduler_config.cuda_graph_sizes[0] >= max_capture_size_limit:
            return

        cuda_graph_sizes = [1, 2, 4]
        # Step size 8 for small batch sizes
        cuda_graph_sizes += list(range(8, 256, 8))
        # Step size 16 for larger batch sizes, up to and including the limit
        cuda_graph_sizes += list(range(256, max_capture_size_limit + 1, 16))
        scheduler_config.cuda_graph_sizes = cuda_graph_sizes
        logger.info(
            "Overriding max cuda graph capture size to "
            "%d for performance.", max_capture_size_limit)


class MambaModelConfig(VerifyAndUpdateConfig):

    # TODO(@tdoublep) find a better way to do this than whitelist
    # Architectures for which prefix caching is kept enabled (support for
    # their Mamba2 layers is experimental); it is disabled for all others.
    _MAMBA2_MODELS = frozenset({
        "BambaForCausalLM",
        "FalconH1ForCausalLM",
        "GraniteMoeHybridForCausalLM",
        "Mamba2ForCausalLM",
        "NemotronHForCausalLM",
        "Zamba2ForCausalLM",
    })

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Apply V1 cache and attention settings for mamba-based models.

        When ``envs.VLLM_USE_V1`` is set this:

        - sets ``cache_config.mamba_block_size`` to ``max_model_len``
          (prefix-caching logic may override it later);
        - keeps prefix caching enabled only for architectures in
          ``_MAMBA2_MODELS`` and disables it otherwise;
        - disables cascade attention, which is not supported for hybrid
          models.

        Args:
            vllm_config: vLLM Config
        """

        if not envs.VLLM_USE_V1:
            return

        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        # Set mamba block size to max_model_len (this may get
        # overridden by prefix caching logic later)
        cache_config.mamba_block_size = model_config.max_model_len

        if cache_config.enable_prefix_caching:
            if model_config.architecture in cls._MAMBA2_MODELS:
                logger.warning(
                    "Prefix caching is currently enabled. "
                    "Its support for Mamba2 layers is experimental. "
                    "Please report any issues you may observe.")
            else:
                logger.info("Hybrid or mamba-based model detected without "
                            "support for prefix caching: disabling.")
                cache_config.enable_prefix_caching = False

        # TODO(tdoublep): remove once cascade attention is supported
        logger.info("Disabling cascade attention since it is not supported "
                    "for hybrid models.")
        model_config.disable_cascade_attn = True


class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure that page size of attention layers is greater than or
        equal to the mamba layers. If not, automatically set the attention
        block size to ensure that it is. If the attention page size is
        strictly greater than the mamba page size, we pad the mamba page size
        to make them equal.

        The settings shared by all mamba-based models
        (``MambaModelConfig.verify_and_update_config``) are applied first.
        Does nothing unless ``envs.VLLM_USE_V1`` is set.

        Args:
            vllm_config: vLLM Config
        """

        if not envs.VLLM_USE_V1:
            return

        # Apply the common mamba-model settings (mamba block size, prefix
        # caching support check, cascade attention) before sizing pages.
        MambaModelConfig.verify_and_update_config(vllm_config)

        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # get attention page size (for 1 token)
        attn_page_size_1_token = FullAttentionSpec(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype).page_size_bytes

        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )

        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=model_config.max_model_len,
        ).page_size_bytes

        if cache_config.enable_prefix_caching:
            # With prefix caching, select attention block size to
            # optimize for mamba kernel performance

            # mamba SSD kernel uses a chunk_size, e.g. 256
            # Align the block to the kernel: use lowest multiple of chunk_size
            # of attention tokens that would fit mamba_page_size:
            # e.g. for mamba page size = 788kB
            #          attn_1_token = 2kB -> fits ~394 tokens
            #      then round up to a multiple of 256 -> 512 tokens
            # End result:
            #  attn_block_size = 512
            #  mamba_block_size = 512 (aligned to a multiple of chunk_size)
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.
            chunk_size = model_config.get_mamba_chunk_size()
            attn_tokens_per_mamba_state = cdiv(mamba_page_size,
                                               attn_page_size_1_token)
            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state,
                                                chunk_size)
            # NOTE(review): if the user set a block_size larger than
            # attn_block_size, it is kept below while mamba_block_size stays
            # at attn_block_size — confirm the two may differ.
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, select minimum valid attention block size
            # to minimize mamba state padding

            # some attention backends (e.g. FA) only support setting
            # block size to multiple of 16, so let's suggest a value
            # that would work (note: FA is currently not compatible
            # with mamba layers, use FlashInfer instead).
            attn_block_size = 16 * cdiv(mamba_page_size,
                                        16 * attn_page_size_1_token)

        # override attention block size if either (a) the
        # user has not set it or (b) the user has set it
        # too small.
        if (cache_config.block_size is None
                or cache_config.block_size < attn_block_size):
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size)

        # compute new attention page size
        attn_page_size = cache_config.block_size * attn_page_size_1_token

        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # don't need to pad mamba page size
            return

        # pad mamba page size to exactly match attention; an unset (None)
        # padded size also compares unequal and is therefore updated.
        if cache_config.mamba_page_size_padded != attn_page_size:
            cache_config.mamba_page_size_padded = attn_page_size
            mamba_padding_pct = 100 * (attn_page_size -
                                       mamba_page_size) / mamba_page_size
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.", mamba_padding_pct)


class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Choose the KV-cache dtype for DeepSeek-V3.2.

        ``"auto"`` and any ``fp8*`` cache dtype are mapped to the custom
        ``"fp8_ds_mla"`` format; ``"bfloat16"`` is mapped to ``"auto"``.
        """
        # Mirror the check in vllm/model_executor/models/deepseek_v2.py
        is_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
        assert is_v32

        # For DeepSeekV3.2 the custom fp8 format is the default ("auto").
        cache_config = vllm_config.cache_config
        requested = cache_config.cache_dtype
        if requested == "auto" or requested.startswith("fp8"):
            cache_config.cache_dtype = "fp8_ds_mla"
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
        elif requested == "bfloat16":
            cache_config.cache_dtype = "auto"
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")


# Architecture name -> adapter class whose ``verify_and_update_config``
# adjusts the vLLM config for that architecture. Keys presumably match
# entries of ``hf_config.architectures`` — verify against the lookup site.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "FalconMambaForCausalLM": MambaModelConfig,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
}