config.py 21.8 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
4
from math import lcm
5
6
from typing import TYPE_CHECKING

7
from vllm.attention.backends.registry import AttentionBackendEnum
8
from vllm.logger import init_logger
9
from vllm.model_executor.models import ModelRegistry
10
from vllm.platforms import current_platform
11
from vllm.utils.math_utils import cdiv, round_up
12
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
13
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
14
15
16
17
18
19
20
21
22
23
24
25
26

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base hook for per-architecture adjustments of a ``VllmConfig``.

    Subclasses override :meth:`verify_and_update_config` to validate and
    mutate the given config in place for one model architecture.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate and update ``vllm_config`` in place.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError


27
28
29
30
31
32
33
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    """Derive the causal-mask flag for Gemma3 text models from the HF config."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``hf_config.is_causal`` from ``use_bidirectional_attention``.

        The class inherits from :class:`VerifyAndUpdateConfig` so that it
        matches the ``type[VerifyAndUpdateConfig]`` value type declared for
        ``MODELS_CONFIG_MAP``. HF configs that do not define
        ``use_bidirectional_attention`` are treated as causal instead of
        raising ``AttributeError``.

        Args:
            vllm_config: vLLM Config; its HF config is updated in place.
        """
        hf_config = vllm_config.model_config.hf_config
        # Bidirectional variants are marked non-causal; a missing attribute
        # means the regular (causal) model.
        use_bidirectional = getattr(hf_config, "use_bidirectional_attention", False)
        hf_config.is_causal = not use_bidirectional


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class GteNewModelConfig(VerifyAndUpdateConfig):
    """Adapt GTE ``NewConfig`` HF configs to the attributes vLLM reads."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Switch the activation name to ``geglu`` and attach ``rotary_kwargs``.

        The HF config must be a ``NewConfig`` whose ``hidden_act`` is
        ``"gelu"``; both conditions are checked with ``assert``.

        Args:
            vllm_config: vLLM Config; its HF config is updated in place.
        """
        hf_cfg = vllm_config.model_config.hf_config

        assert hf_cfg.__class__.__name__ == "NewConfig"
        assert hf_cfg.hidden_act == "gelu"

        hf_cfg.hidden_act = "geglu"

        # Rotary dim defaults to the full per-head size when the HF config
        # does not specify rotary_emb_dim.
        size_per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        hf_cfg.rotary_kwargs = dict(
            head_size=size_per_head,
            rotary_dim=getattr(hf_cfg, "rotary_emb_dim", size_per_head),
            max_position=hf_cfg.max_position_embeddings,
            rope_parameters=hf_cfg.rope_parameters,
        )


53
54
55
56
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Pooler defaults for Jamba sequence-classification models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Default ``pooler_config.use_activation`` to ``False`` when unset.

        An explicit user value (anything other than ``None``) is kept.
        """
        pooler = vllm_config.model_config.pooler_config
        if pooler.use_activation is not None:
            return
        pooler.use_activation = False
59
60


61
62
63
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    """Attach RoPE kwargs for Jina XLM-RoBERTa models with rotary positions."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Populate ``hf_config.rotary_kwargs`` when rotary embeddings are used.

        Does nothing unless ``position_embedding_type == "rotary"``. In that
        case the HF config must be an ``XLMRobertaFlashConfig`` (checked with
        ``assert``).

        Args:
            vllm_config: vLLM Config; its HF config is updated in place.
        """
        model_config = vllm_config.model_config
        hf_cfg = model_config.hf_config

        if hf_cfg.position_embedding_type != "rotary":
            return

        assert hf_cfg.__class__.__name__ == "XLMRobertaFlashConfig"

        size_per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads

        # Jina-embeddings-v3 ships max_position_embeddings=8194, which is not
        # divisible by triton's num_warps (4 or 8 by default). Under
        # torch.compile this causes out-of-bound indexing into the RoPE cache
        # for long prompts, so round it up to a multiple of 8 unless the model
        # runs in eager mode.
        max_pos = hf_cfg.max_position_embeddings
        if not model_config.enforce_eager:
            max_pos = round_up(max_pos, 8)

        hf_cfg.rotary_kwargs = dict(
            head_size=size_per_head,
            rotary_dim=getattr(hf_cfg, "rotary_emb_dim", size_per_head),
            max_position=max_pos,
            rope_parameters=hf_cfg.rope_parameters,
        )


class NomicBertModelConfig(VerifyAndUpdateConfig):
    """Translate Nomic BERT HF configs into the BERT-style names vLLM reads."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Normalize the Nomic BERT HF config and resolve ``max_model_len``.

        First the Nomic-specific fields (``n_embd``, ``n_layer``, ...) are
        mapped onto vLLM's names and ``rotary_kwargs`` is attached. Then
        ``max_model_len`` is either clamped to ``max_trained_positions`` (no
        overrides) or re-verified against it (context extension requested).

        Args:
            vllm_config: vLLM Config, updated in place.
        """
        config = vllm_config.model_config.hf_config
        max_trained_positions = NomicBertModelConfig._normalize_hf_config(config)
        NomicBertModelConfig._update_max_model_len(
            vllm_config, config, max_trained_positions
        )

    @staticmethod
    def _normalize_hf_config(config) -> int:
        """Map Nomic BERT fields onto vLLM names; return the trained positions.

        Args:
            config: The ``NomicBertConfig`` HF config, updated in place.

        Returns:
            ``max_trained_positions`` from the config (2048 if absent).
        """
        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]

        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        # Only a single ``bias`` flag is propagated, so all three must agree.
        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
        max_trained_positions = getattr(config, "max_trained_positions", 2048)

        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_emb_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }
        return max_trained_positions

    @staticmethod
    def _update_max_model_len(
        vllm_config: "VllmConfig", config, max_trained_positions: int
    ) -> None:
        """Clamp ``max_model_len`` by default, or re-verify a requested one.

        Args:
            vllm_config: vLLM Config, updated in place.
            config: The already-normalized Nomic BERT HF config.
            max_trained_positions: Position limit the model was trained with.
        """
        model_config = vllm_config.model_config

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_parameters.
        # See #17785 #18755
        if (
            not model_config.hf_overrides
            and model_config.original_max_model_len is None
        ):
            # Default: clamp max_model_len to at most max_trained_positions.
            # For nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len, max_trained_positions)

            vllm_config.recalculate_max_model_len(max_model_len)
            logger.warning(
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                max_model_len_before,
                vllm_config.model_config.max_model_len,
            )
            return

        # Context extension requested: re-verify max_model_len so that it
        # cannot exceed the position embeddings.
        hf_text_config = model_config.hf_text_config

        if isinstance(model_config.hf_overrides, dict):
            # hf_overrides_kw
            max_model_len = model_config.hf_overrides.get(
                "max_model_len", model_config.max_model_len
            )
        else:
            # hf_overrides_fn
            # This might be overridden by sentence_bert_config.json.
            max_model_len = model_config.max_model_len

        # reset hf_text_config for recalculate_max_model_len.
        if hasattr(hf_text_config, "max_model_len"):
            delattr(hf_text_config, "max_model_len")
        hf_text_config.max_position_embeddings = max_trained_positions
        hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

        # The priority of sentence_bert_config.json is higher than
        # max_position_embeddings, so drop its max_seq_length before
        # re-verifying. encoder_config may presumably be None (no
        # sentence-transformers config); skip the pop then instead of
        # raising AttributeError.
        encoder_config = model_config.encoder_config
        if encoder_config is not None:
            encoder_config = deepcopy(encoder_config)
            encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

        vllm_config.recalculate_max_model_len(max_model_len)


184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    """Pooler defaults for Qwen2 process reward models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Default ``pooler_config.step_tag_id`` to 151651 when unset."""
        # NOTE(review): 151651 is presumably the tokenizer id of the step
        # separator token used by Qwen PRM checkpoints — confirm.
        pooler = vllm_config.model_config.pooler_config
        if pooler.step_tag_id is not None:
            return
        pooler.step_tag_id = 151651


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    """Pooler defaults for Qwen2 reward models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Default ``pooler_config.softmax`` to ``False`` when unset."""
        pooler = vllm_config.model_config.pooler_config
        if pooler.softmax is not None:
            return
        pooler.softmax = False


202
203
204
205
206
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Scoring setup for original Qwen3 Reranker checkpoints."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``hf_config.method`` for original Qwen3 Reranker checkpoints.

        Acts only when the HF config has a truthy
        ``is_original_qwen3_reranker``. In that case ``classifier_from_token``
        must hold exactly two tokens (checked with ``assert``), and
        ``method`` is set to ``"from_2_way_softmax"``.
        """
        hf_cfg = vllm_config.model_config.hf_config

        if not getattr(hf_cfg, "is_original_qwen3_reranker", False):
            return

        classifier_tokens = getattr(hf_cfg, "classifier_from_token", None)
        assert classifier_tokens is not None and len(classifier_tokens) == 2, (
            "Try loading the original Qwen3 Reranker?, see: "
            "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"
        )
        hf_cfg.method = "from_2_way_softmax"
220
221


222
223
224
225
226
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Single-label scoring defaults for Jina VL rankers."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Force ``num_labels = 1`` and default ``logit_bias`` to 2.65."""
        model_config = vllm_config.model_config
        model_config.hf_config.num_labels = 1

        # NOTE(review): 2.65 is a model-specific constant whose origin is not
        # documented here — confirm against the upstream model card.
        pooler = model_config.pooler_config
        if pooler.logit_bias is None:
            pooler.logit_bias = 2.65
230
231


232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    """Adapt ``GteConfig`` HF configs to the attributes vLLM reads."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Switch the activation name to ``geglu`` and attach ``rotary_kwargs``.

        The HF config must be a ``GteConfig`` whose ``hidden_act`` is
        ``"gelu"``; both conditions are checked with ``assert``.

        Args:
            vllm_config: vLLM Config; its HF config is updated in place.
        """
        hf_cfg = vllm_config.model_config.hf_config

        assert hf_cfg.__class__.__name__ == "GteConfig"
        assert hf_cfg.hidden_act == "gelu"

        hf_cfg.hidden_act = "geglu"

        # Rotary dim defaults to the full per-head size when the HF config
        # does not specify rotary_emb_dim.
        size_per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        hf_cfg.rotary_kwargs = dict(
            head_size=size_per_head,
            rotary_dim=getattr(hf_cfg, "rotary_emb_dim", size_per_head),
            max_position=hf_cfg.max_position_embeddings,
            rope_parameters=hf_cfg.rope_parameters,
        )


251
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    """Defaults for GPT-OSS: reasoning parser and CUDA graph capture size."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Apply GPT-OSS defaults the user has not set explicitly.

        - An empty ``reasoning_parser`` becomes ``"openai_gptoss"``.
        - When neither ``cudagraph_capture_sizes`` nor
          ``max_cudagraph_capture_size`` is set, the max capture size is
          raised to 1024 for performance.
        """
        so_config = vllm_config.structured_outputs_config
        if so_config.reasoning_parser == "":
            so_config.reasoning_parser = "openai_gptoss"

        # Raising the max capture size from 512 to 1024 helps performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 83.
        compilation_config = vllm_config.compilation_config
        user_chose_sizes = (
            compilation_config.cudagraph_capture_sizes is not None
            or compilation_config.max_cudagraph_capture_size is not None
        )
        if user_chose_sizes:
            return

        max_capture_size = 1024
        compilation_config.max_cudagraph_capture_size = max_capture_size
        logger.info(
            "Overriding max cuda graph capture size to %d for performance.",
            max_capture_size,
        )
272
273


274
275
276
277
278
279
280
281
282
283
284
class MambaModelConfig(VerifyAndUpdateConfig):
    """Cache-related defaults for Mamba-based (and hybrid) models."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Adjust cache and model settings for models with Mamba layers.

        - If prefix caching is enabled and the model supports Mamba prefix
          caching, keep it (experimental) and default ``mamba_block_size``
          to the attention ``block_size``. If the model does not support it,
          disable prefix caching.
        - If ``mamba_block_size`` is still unset, default it to
          ``max_model_len``.
        - Disable cascade attention, which hybrid models do not support.

        Args:
            vllm_config: vLLM Config, updated in place.
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if model_config.supports_mamba_prefix_caching:
                logger.info(
                    "Warning: Prefix caching is currently enabled. "
                    "Its support for Mamba layers is experimental. "
                    "Please report any issues you may observe."
                )
                # By default, mamba block size will be set to max_model_len (see
                # below). When enabling prefix caching, we align mamba block size
                # to the block size as the basic granularity for prefix caching.
                # NOTE(review): if cache_config.block_size is still None here,
                # the fallback below ends up using max_model_len — confirm the
                # block size is always resolved before this hook runs.
                if cache_config.mamba_block_size is None:
                    cache_config.mamba_block_size = cache_config.block_size
            else:
                logger.info(
                    "Hybrid or mamba-based model detected without "
                    "support for prefix caching: disabling."
                )
                cache_config.enable_prefix_caching = False

        if cache_config.mamba_block_size is None:
            cache_config.mamba_block_size = model_config.max_model_len

        # TODO(tdoublep): remove once cascade attention is supported
        logger.info(
            "Disabling cascade attention since it is not supported for hybrid models."
        )
        model_config.disable_cascade_attn = True
314
315


316
317
318
319
320
321
322
323
324
325
326
327
328
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    """Align attention and Mamba KV-cache page sizes for hybrid models."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure that page size of attention layers is greater than or
        equal to the mamba layers. If not, automatically set the attention
        block size to ensure that it is. If the attention page size is
        strictly greater than the mamba page size, we pad the mamba page size
        to make them equal.

        ``MambaModelConfig.verify_and_update_config`` is applied first
        (prefix-caching and ``mamba_block_size`` defaults, cascade attention).

        Args:
            vllm_config: vLLM Config, updated in place.
        """
        # Save the user input before it gets modified by MambaModelConfig
        mamba_block_size = vllm_config.cache_config.mamba_block_size

        # Apply the generic Mamba cache/model defaults first.
        MambaModelConfig.verify_and_update_config(vllm_config)

        attention_config = vllm_config.attention_config
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # get attention page size (for 1 token)
        # Attention backend constraints:
        # - FlashAttention (FA) requires block size to be multiple of 16
        # - MLA (Multi-head Latent Attention) requires larger alignment:
        #   * CUTLASS_MLA backend: kernel_block_size 128 alignment
        #   * Other MLA backends: kernel_block_size 64 alignment
        if model_config.use_mla:
            use_cutlass_mla = (
                attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
            )
            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
            attn_page_size_1_token = MLAAttentionSpec(
                block_size=1,
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                dtype=kv_cache_dtype,
            ).page_size_bytes
        else:
            kernel_block_alignment_size = 16
            if (
                current_platform.is_device_capability(100)
                and model_config.get_head_size() == 256
                and (
                    attention_config.backend is None
                    or attention_config.backend == AttentionBackendEnum.FLASHINFER
                )
            ):
                # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports
                # that head size 256 and block size 16 is not supported on
                # blackwell.
                kernel_block_alignment_size = 32
            attn_page_size_1_token = FullAttentionSpec(
                block_size=1,
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                dtype=kv_cache_dtype,
            ).page_size_bytes

        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )

        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=model_config.max_model_len,
        ).page_size_bytes

        # Model may be marked as is_hybrid
        #  but mamba is skipped via config,
        #  return directly
        if mamba_page_size == 0:
            return

        if cache_config.enable_prefix_caching:
            # With prefix caching, select attention block size to
            # optimize for mamba kernel performance

            # Mamba2 SSD kernel uses a chunk_size, e.g. 256

            # Align the block to the kernel: use lowest multiple of chunk_size
            # of attention tokens that would fit mamba_page_size:
            # e.g. for mamba page size = 788kB
            #          attn_1_token = 2kB -> fits ~394 tokens
            #      then round up to a multiple of 256 -> 512 tokens
            # End result:
            #  attn_block_size = 512
            #  mamba_block_size = 512 (aligned to a multiple of chunk_size)
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.

            # A user-supplied mamba_block_size (captured before
            # MambaModelConfig ran) takes precedence over the model's
            # mamba chunk size.
            base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, select minimum valid attention block size
            # to minimize mamba state padding

            # Calculate minimum attention block size that satisfies both:
            # 1. Backend alignment requirements (kernel_block_alignment_size)
            # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
            attn_block_size = kernel_block_alignment_size * cdiv(
                mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
            )

        # override attention block size if either (a) the
        # user has not set it or (b) the user has set it
        # too small.
        if cache_config.block_size is None or cache_config.block_size < attn_block_size:
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size,
            )

        # compute new attention page size
        attn_page_size = cache_config.block_size * attn_page_size_1_token

        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # don't need to pad mamba page size
            return

        # pad mamba page size to exactly match attention
        if (
            cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size
        ):
            cache_config.mamba_page_size_padded = attn_page_size
            mamba_padding_pct = (
                100 * (attn_page_size - mamba_page_size) / mamba_page_size
            )
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.",
                mamba_padding_pct,
            )
467
468


469
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    """KV-cache dtype adjustments for DeepSeek-V3.2."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Map the requested kv-cache dtype onto the formats DeepSeek-V3.2 uses.

        Any ``fp8*`` cache dtype is replaced by the custom ``"fp8_ds_mla"``
        format, and ``"bfloat16"`` is replaced by ``"auto"``. Other values
        are left unchanged.
        """
        # Mirror the check in vllm/model_executor/models/deepseek_v2.py:
        # the presence of ``index_topk`` identifies a V3.2 config.
        assert hasattr(vllm_config.model_config.hf_config, "index_topk")

        cache_config = vllm_config.cache_config
        requested_dtype = cache_config.cache_dtype
        if requested_dtype.startswith("fp8"):
            cache_config.cache_dtype = "fp8_ds_mla"
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
        elif requested_dtype == "bfloat16":
            # NOTE(review): "auto" follows the model dtype; presumably that is
            # bfloat16 for this model, as the log message states — confirm.
            cache_config.cache_dtype = "auto"
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
489
490


491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    """Mamba SSM cache dtype defaults for NemotronH."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Resolve an ``"auto"`` ``mamba_ssm_cache_dtype`` for NemotronH.

        The value comes from the HF config's ``mamba_ssm_cache_dtype`` when
        present and falls back to ``"float16"`` otherwise. Any explicit value
        other than ``"auto"`` is left untouched.
        """
        cache_config = vllm_config.cache_config
        if cache_config.mamba_ssm_cache_dtype != "auto":
            return

        resolved_dtype = getattr(
            vllm_config.model_config.hf_config, "mamba_ssm_cache_dtype", "float16"
        )
        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            resolved_dtype,
        )
        cache_config.mamba_ssm_cache_dtype = resolved_dtype


511
512
513
# HF architecture name -> hook class whose ``verify_and_update_config`` adjusts
# the VllmConfig for that architecture.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    # Architectures without a ``ForCausalLM`` head
    # (embedding, reward, classification and ranking models).
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    # Causal LM architectures.
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "FalconMambaForCausalLM": MambaModelConfig,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
    "NemotronHForCausalLM": NemotronHForCausalLMConfig,
}