config.py 22.5 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
4
from math import lcm
5
6
from typing import TYPE_CHECKING

7
from vllm.attention.backends.registry import AttentionBackendEnum
8
from vllm.logger import init_logger
9
from vllm.model_executor.models import ModelRegistry
10
from vllm.platforms import current_platform
11
from vllm.utils.math_utils import cdiv, round_up
12
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
13
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
14
15

if TYPE_CHECKING:
16
    from vllm.config import ModelConfig, VllmConfig
17
18
19
20
21
22
23

# Module-level logger used by the config hooks defined in this file.
logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base class for per-architecture config fix-up hooks.

    ``MODELS_CONFIG_MAP`` maps architecture names to subclasses of this
    class. Subclasses override one or both hooks; the defaults here
    accept any config and change nothing.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate/adjust the whole vLLM config. Default: no-op."""
        return None

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Validate/adjust only the model config. Default: no-op."""
        return None
29
30


31
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Set ``hf_config.is_causal`` to the negation of the HF
        ``use_bidirectional_attention`` flag."""
        hf_config = model_config.hf_config
        bidirectional = hf_config.use_bidirectional_attention
        hf_config.is_causal = not bidirectional


38
39
class GteNewModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Adapt a GTE ``NewConfig`` HF config for vLLM.

        Requires ``hidden_act == "gelu"`` and rewrites it to ``"geglu"``.
        Stores ``rotary_emb_dim / head_dim`` (1.0 when ``rotary_emb_dim`` is
        absent) as ``rope_parameters["partial_rotary_factor"]`` and collects
        the RoPE constructor arguments in ``hf_config.rotary_kwargs``.
        """
        hf = model_config.hf_config

        assert hf.__class__.__name__ == "NewConfig"
        assert hf.hidden_act == "gelu"
        hf.hidden_act = "geglu"

        head_size = hf.hidden_size // hf.num_attention_heads
        rotary_fraction = getattr(hf, "rotary_emb_dim", head_size) / head_size
        hf.rope_parameters["partial_rotary_factor"] = rotary_fraction

        hf.rotary_kwargs = dict(
            head_size=head_size,
            max_position=hf.max_position_embeddings,
            rope_parameters=hf.rope_parameters,
        )


58
59
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.use_activation`` to ``False`` when the
        user left it unset (``None``)."""
        pooler = model_config.pooler_config
        if pooler.use_activation is not None:
            return
        pooler.use_activation = False
64
65


66
67
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Prepare RoPE arguments for rotary-position Jina XLM-RoBERTa models.

        Does nothing unless ``position_embedding_type == "rotary"``. In that
        case it requires an ``XLMRobertaFlashConfig``, records the partial
        rotary factor and stores the RoPE constructor arguments in
        ``hf_config.rotary_kwargs``.
        """
        hf = model_config.hf_config
        if hf.position_embedding_type != "rotary":
            return

        assert hf.__class__.__name__ == "XLMRobertaFlashConfig"

        head_size = hf.hidden_size // hf.num_attention_heads

        # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
        # out-of-bound index issue at RoPE for long prompts with torch.compile,
        # because it can't be divided by triton num_warps(default=4 or 8).
        # To deal with this, we increase max_position to multiple of n_warps,
        # so that triton kernel won't hit out-of-bound index in RoPE cache.
        max_position = (
            hf.max_position_embeddings
            if model_config.enforce_eager
            else round_up(hf.max_position_embeddings, 8)
        )

        rotary_fraction = getattr(hf, "rotary_emb_dim", head_size) / head_size
        hf.rope_parameters["partial_rotary_factor"] = rotary_fraction

        hf.rotary_kwargs = dict(
            head_size=head_size,
            max_position=max_position,
            rope_parameters=hf.rope_parameters,
        )


94
95
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Disable causal masking and translate the HF ``pooling`` name.

        ``hf_config.pooling`` values ``avg``/``cls``/``last`` map to vLLM's
        ``MEAN``/``CLS``/``LAST`` and are written to
        ``pooler_config.pooling_type``.

        Raises:
            ValueError: if ``hf_config.pooling`` is not one of those names.
        """
        from vllm.config.pooler import PoolingTypeStr

        hf_config = model_config.hf_config
        hf_config.is_causal = False

        hf_to_vllm_pooling: dict[str, PoolingTypeStr] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        hf_pooling = hf_config.pooling
        if hf_pooling not in hf_to_vllm_pooling:
            raise ValueError(f"pool_type {hf_config.pooling} not supported")
        model_config.pooler_config.pooling_type = hf_to_vllm_pooling[hf_pooling]
112
113


114
115
116
117
118
119
120
class NomicBertModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Adapt a ``NomicBertConfig`` to vLLM and re-derive ``max_model_len``.

        The Nomic-specific HF attributes are copied onto the BERT-style names
        (``hidden_size``, ``num_hidden_layers``, ``intermediate_size``,
        ``layer_norm_eps``, ``bias``). The activation is mapped
        (``swiglu`` -> ``silu``, ``gelu`` kept), and ``rotary_kwargs`` is
        built with ``max_trained_positions`` (default 2048) as the RoPE
        ``max_position``.

        ``max_model_len`` is then handled in one of two ways:

        * Default (no ``hf_overrides`` and ``original_max_model_len`` is
          None): it is capped at ``max_trained_positions`` (context extension
          disabled) and a warning is logged.
        * Otherwise the requested length (``hf_overrides["max_model_len"]``
          for dict overrides, else the current ``max_model_len``) is
          re-verified after resetting the text config's position limits.
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]

        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        # A single `bias` flag is exposed, so all three bias settings must agree.
        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        max_trained_positions = getattr(config, "max_trained_positions", 2048)

        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_parameters.
        # See #17785 #18755
        model_config = vllm_config.model_config
        if not model_config.hf_overrides and model_config.original_max_model_len is None:
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len, max_trained_positions)

            vllm_config.recalculate_max_model_len(max_model_len)
            logger.warning(
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                max_model_len_before,
                vllm_config.model_config.max_model_len,
            )
            return

        # We need to re-verify max_model_len to avoid lengths
        # greater than position_embedding.
        hf_text_config = model_config.hf_text_config

        if isinstance(model_config.hf_overrides, dict):
            # hf_overrides_kw
            max_model_len = model_config.hf_overrides.get(
                "max_model_len", model_config.max_model_len
            )
        else:
            # hf_overrides_fn
            # This might be overridden by sentence_bert_config.json.
            max_model_len = model_config.max_model_len

        # reset hf_text_config for recalculate_max_model_len.
        if hasattr(hf_text_config, "max_model_len"):
            delattr(hf_text_config, "max_model_len")
        hf_text_config.max_position_embeddings = max_trained_positions
        hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

        # The priority of sentence_bert_config.json is higher
        # than max_position_embeddings, so drop its max_seq_length.
        # encoder_config may be None (nothing to drop); copying and calling
        # .pop() on None would raise AttributeError.
        if model_config.encoder_config is not None:
            encoder_config = deepcopy(model_config.encoder_config)
            encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

        vllm_config.recalculate_max_model_len(max_model_len)


208
209
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.step_tag_id`` to 151651 when unset."""
        # NOTE(review): 151651 is presumably the Qwen2 tokenizer id of the
        # step-separator token used by the PRM checkpoints — confirm.
        default_step_tag_id = 151651
        pooler = model_config.pooler_config
        if pooler.step_tag_id is not None:
            return
        pooler.step_tag_id = default_step_tag_id


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.softmax`` to ``False`` when unset."""
        pooler = model_config.pooler_config
        if pooler.softmax is not None:
            return
        pooler.softmax = False


226
227
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Configure the original Qwen3 reranker checkpoint format.

        Only acts when ``hf_config.is_original_qwen3_reranker`` is truthy.
        It then requires exactly two ``classifier_from_token`` entries and
        sets ``hf_config.method = "from_2_way_softmax"``.
        """
        hf = model_config.hf_config
        if not getattr(hf, "is_original_qwen3_reranker", False):
            return

        classifier_tokens = getattr(hf, "classifier_from_token", None)
        assert classifier_tokens is not None and len(classifier_tokens) == 2, (
            "Try loading the original Qwen3 Reranker?, see: "
            "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
        )
        model_config.hf_config.method = "from_2_way_softmax"
244
245


246
247
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Force a single output label and default ``logit_bias`` to 2.65
        when the user left it unset."""
        model_config.hf_config.num_labels = 1
        pooler = model_config.pooler_config
        if pooler.logit_bias is not None:
            return
        pooler.logit_bias = 2.65
254
255


256
257
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Adapt a Snowflake ``GteConfig`` HF config for vLLM.

        Requires ``hidden_act == "gelu"`` and rewrites it to ``"geglu"``.
        Stores ``rotary_emb_dim / head_dim`` (1.0 when ``rotary_emb_dim`` is
        absent) as ``rope_parameters["partial_rotary_factor"]`` and collects
        the RoPE constructor arguments in ``hf_config.rotary_kwargs``.
        """
        hf = model_config.hf_config

        assert hf.__class__.__name__ == "GteConfig"
        assert hf.hidden_act == "gelu"
        hf.hidden_act = "geglu"

        head_size = hf.hidden_size // hf.num_attention_heads
        rotary_fraction = getattr(hf, "rotary_emb_dim", head_size) / head_size
        hf.rope_parameters["partial_rotary_factor"] = rotary_fraction

        hf.rotary_kwargs = dict(
            head_size=head_size,
            max_position=hf.max_position_embeddings,
            rope_parameters=hf.rope_parameters,
        )


276
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Apply GPT-OSS defaults.

        * An empty ``reasoning_parser`` becomes ``"openai_gptoss"``.
        * When the user set neither ``cudagraph_capture_sizes`` nor
          ``max_cudagraph_capture_size``, the latter is raised to 1024.
        """
        outputs_config = vllm_config.structured_outputs_config
        if outputs_config.reasoning_parser == "":
            outputs_config.reasoning_parser = "openai_gptoss"

        # Increase the max capture size from 512 to 1024 for performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 83.
        compilation_config = vllm_config.compilation_config
        user_chose_capture_sizes = (
            compilation_config.cudagraph_capture_sizes is not None
            or compilation_config.max_cudagraph_capture_size is not None
        )
        if user_chose_capture_sizes:
            return

        capture_size = 1024
        compilation_config.max_cudagraph_capture_size = capture_size
        logger.info(
            "Overriding max cuda graph capture size to %d for performance.",
            capture_size,
        )
297
298


299
300
301
302
303
304
305
306
307
308
309
class MambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Reconcile prefix caching with Mamba layers and default the mamba
        block size.

        If prefix caching is enabled but the model does not support it for
        Mamba layers, prefix caching is turned off. If it is supported, an
        unset ``mamba_block_size`` is aligned to ``block_size``. Any
        ``mamba_block_size`` still unset afterwards falls back to
        ``max_model_len``.

        Args:
            vllm_config: vLLM Config
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if not model_config.supports_mamba_prefix_caching:
                logger.info(
                    "Hybrid or mamba-based model detected without "
                    "support for prefix caching: disabling."
                )
                cache_config.enable_prefix_caching = False
            else:
                logger.info(
                    "Warning: Prefix caching is currently enabled. "
                    "Its support for Mamba layers is experimental. "
                    "Please report any issues you may observe."
                )
                # Prefix caching works at block granularity, so the mamba
                # block size follows the attention block size rather than
                # the max_model_len default applied below.
                if cache_config.mamba_block_size is None:
                    cache_config.mamba_block_size = cache_config.block_size

        if cache_config.mamba_block_size is None:
            cache_config.mamba_block_size = model_config.max_model_len

334

335
336
337
338
339
340
341
342
343
344
345
346
347
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Make the attention page size at least as large as the mamba page size.

        Runs ``MambaModelConfig``'s checks first. It then picks an attention
        block size (in tokens) whose page holds at least one mamba state while
        respecting the attention backend's block-size alignment, and raises
        ``cache_config.block_size`` to it when unset or too small. If the
        resulting attention page is strictly larger than the mamba page,
        ``mamba_page_size_padded`` is set so the two are exactly equal.

        Args:
            vllm_config: vLLM Config
        """
        # Capture the user's value before MambaModelConfig may fill it in.
        user_mamba_block_size = vllm_config.cache_config.mamba_block_size

        MambaModelConfig.verify_and_update_config(vllm_config)

        attention_config = vllm_config.attention_config
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        kv_cache_dtype = (
            model_config.dtype
            if cache_config.cache_dtype == "auto"
            else STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
        )

        # Attention backend block-size alignment:
        # - FlashAttention (FA): multiple of 16
        # - MLA: 128 for the CUTLASS_MLA backend, 64 for other MLA backends
        # - FlashInfer on Blackwell with head size 256: 32, because
        #   https://github.com/flashinfer-ai/flashinfer/issues/1993 reports
        #   that head size 256 with block size 16 is not supported there.
        if model_config.use_mla:
            is_cutlass_mla = (
                attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
            )
            alignment = 128 if is_cutlass_mla else 64
            attn_spec_cls = MLAAttentionSpec
        else:
            alignment = 16
            flashinfer_hd256_on_blackwell = (
                current_platform.is_device_capability_family(100)
                and model_config.get_head_size() == 256
                and (
                    attention_config.backend is None
                    or attention_config.backend == AttentionBackendEnum.FLASHINFER
                )
            )
            if flashinfer_hd256_on_blackwell:
                alignment = 32
            attn_spec_cls = FullAttentionSpec

        # Bytes needed to store the attention KV of a single token.
        attn_bytes_per_token = attn_spec_cls(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype,
        ).page_size_bytes

        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )

        # Bytes needed to store one mamba state.
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=model_config.max_model_len,
        ).page_size_bytes

        # A model flagged as hybrid may have mamba skipped via its config;
        # with no mamba state there is nothing to align.
        if mamba_page_size == 0:
            return

        if cache_config.enable_prefix_caching:
            # With prefix caching, the attention block size is chosen to suit
            # the Mamba2 SSD kernel, which works in chunks (e.g. 256 tokens).
            # Use the smallest multiple of the chunk size whose attention page
            # fits one mamba state, e.g.
            #   mamba page = 788kB, attn_1_token = 2kB -> ~394 tokens
            #   round up to a multiple of 256          -> 512 tokens
            # so attn_block_size = mamba_block_size = 512.
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.
            chunk_base = user_mamba_block_size or model_config.get_mamba_chunk_size()
            tokens_per_mamba_state = cdiv(mamba_page_size, attn_bytes_per_token)
            chunk_size = lcm(chunk_base, alignment)
            attn_block_size = chunk_size * cdiv(tokens_per_mamba_state, chunk_size)
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, take the smallest aligned block size
            # whose page holds a mamba state, which minimizes mamba padding.
            attn_block_size = alignment * cdiv(
                mamba_page_size, alignment * attn_bytes_per_token
            )

        # Override the attention block size if the user left it unset or
        # set it too small.
        if cache_config.block_size is None or cache_config.block_size < attn_block_size:
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size,
            )

        attn_page_size = cache_config.block_size * attn_bytes_per_token
        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # Already equal: no padding needed.
            return

        # Pad the mamba page size up to exactly the attention page size.
        padded = cache_config.mamba_page_size_padded
        if padded is None or padded != attn_page_size:
            cache_config.mamba_page_size_padded = attn_page_size
            padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.",
                padding_pct,
            )
486
487


488
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Map the requested kv-cache dtype onto DeepSeek-V3.2's formats.

        Any ``fp8*`` cache dtype becomes the custom ``"fp8_ds_mla"`` format,
        and ``"bfloat16"`` becomes ``"auto"``. Other values are left alone.
        """
        hf_config = vllm_config.model_config.hf_config

        # Mirror the check in vllm/model_executor/models/deepseek_v2.py:
        # a config is treated as V3.2 when it defines `index_topk`.
        assert hasattr(hf_config, "index_topk")

        cache_config = vllm_config.cache_config
        requested_dtype = cache_config.cache_dtype
        if requested_dtype.startswith("fp8"):
            cache_config.cache_dtype = "fp8_ds_mla"
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
        elif requested_dtype == "bfloat16":
            cache_config.cache_dtype = "auto"
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
508
509


510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Resolve an ``"auto"`` ``mamba_ssm_cache_dtype`` for NemotronH.

        The value comes from the HF config's ``mamba_ssm_cache_dtype``
        attribute, or ``"float16"`` when the attribute is absent. Explicit
        (non-``"auto"``) settings are left untouched.
        """
        cache_config = vllm_config.cache_config
        if cache_config.mamba_ssm_cache_dtype != "auto":
            return

        resolved_dtype = getattr(
            vllm_config.model_config.hf_config, "mamba_ssm_cache_dtype", "float16"
        )
        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            resolved_dtype,
        )
        cache_config.mamba_ssm_cache_dtype = resolved_dtype


530
531
532
# Architecture name -> config fix-up hook class. Entries are kept in their
# original insertion order; the comments only label the model families.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    # GTE embedding models
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    # Gemma3 text (bidirectional-capable)
    "Gemma3TextModel": Gemma3TextModelConfig,
    # Bidirectional Llama
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    # Nomic BERT
    "NomicBertModel": NomicBertModelConfig,
    # Qwen reward / reranker models
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    # Jina models
    "XLMRobertaModel": JinaRobertaModelConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    # Jamba classifier
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    # GPT-OSS
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    # Pure Mamba models
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "FalconMambaForCausalLM": MambaModelConfig,
    # DeepSeek V3.2
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
    # NemotronH
    "NemotronHForCausalLM": NemotronHForCausalLMConfig,
}