# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
4
from math import lcm
5
6
from typing import TYPE_CHECKING

7
import vllm.envs as envs
8
from vllm.logger import init_logger
9
from vllm.model_executor.models import ModelRegistry
10
from vllm.platforms import current_platform
11
from vllm.utils.math_utils import cdiv, round_up
12
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
13
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
14
15
16
17
18
19
20
21
22
23
24
25
26

# ``VllmConfig`` is only referenced in (string) annotations in this module,
# so it is imported for static type checkers only and never at runtime.
if TYPE_CHECKING:
    from vllm.config import VllmConfig

# Module-level logger shared by the config hooks in this file.
logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base hook for architecture-specific adjustments of a ``VllmConfig``.

    Subclasses override :meth:`verify_and_update_config` to validate the
    configuration of one model architecture and mutate it in place.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate and update ``vllm_config`` in place.

        The base class provides no behavior.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()


class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Derive ``hf_config.is_causal`` from ``use_bidirectional_attention``.

        When the HF config enables bidirectional attention the model is
        marked non-causal. Configs that do not define the flag are treated
        as causal instead of raising ``AttributeError``.
        """
        hf_config = vllm_config.model_config.hf_config
        use_bidirectional = getattr(hf_config, "use_bidirectional_attention", False)
        hf_config.is_causal = not use_bidirectional


class GteNewModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Adapt a GTE ``NewConfig`` for vLLM.

        Requires ``hidden_act == "gelu"``, rewrites it to ``"geglu"`` and
        stores the rotary-embedding arguments in ``config.rotary_kwargs``.
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "NewConfig", (
            f"expected NewConfig, got {config.__class__.__name__}"
        )
        assert config.hidden_act == "gelu", (
            f"expected hidden_act='gelu', got {config.hidden_act!r}"
        )

        # NOTE(review): the checkpoint's "gelu" is remapped to vLLM's "geglu";
        # presumably the MLP is gated — confirm against the model definition.
        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        config.rotary_kwargs = {
            "head_size": head_dim,
            # rotary_emb_dim, when present, limits RoPE to part of the head.
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rope_theta,
            "rope_scaling": getattr(config, "rope_scaling", None),
        }


class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Default the pooler to not apply an activation to its output.

        A user-provided ``use_activation`` value is left untouched.
        """
        pooler_config = vllm_config.model_config.pooler_config

        if pooler_config.use_activation is None:
            pooler_config.use_activation = False


class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Build ``config.rotary_kwargs`` for rotary Jina XLM-RoBERTa configs.

        Only configs with ``position_embedding_type == "rotary"`` are
        modified; they must be ``XLMRobertaFlashConfig`` instances.
        """
        model_config = vllm_config.model_config
        config = model_config.hf_config

        if config.position_embedding_type != "rotary":
            return

        assert config.__class__.__name__ == "XLMRobertaFlashConfig"

        head_dim = config.hidden_size // config.num_attention_heads

        max_position = config.max_position_embeddings
        # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
        # out-of-bound index issue at RoPE for long prompts with torch.compile,
        # because it can't be divided by triton num_warps(default=4 or 8).
        # To deal with this, we increase max_position to multiple of n_warps,
        # so that triton kernel won't hit out-of-bound index in RoPE cache.
        if not model_config.enforce_eager:
            max_position = round_up(max_position, 8)

        # Prefer rope_theta and fall back to rotary_emb_base only when it is
        # missing. A getattr() default would read rotary_emb_base eagerly and
        # raise AttributeError for configs that define only rope_theta.
        if hasattr(config, "rope_theta"):
            rope_base = config.rope_theta
        else:
            rope_base = config.rotary_emb_base

        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": max_position,
            "base": rope_base,
            "rope_scaling": getattr(config, "rope_scaling", None),
        }


class NomicBertModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Translate a ``NomicBertConfig`` into vLLM's BERT conventions.

        Renames Nomic-specific fields (``n_embd``, ``n_layer``, ``n_inner``,
        ``layer_norm_epsilon``, ``activation_function``) to the standard
        names, builds ``config.rotary_kwargs`` and then re-derives
        ``max_model_len``:

        * With no ``hf_overrides`` and ``original_max_model_len`` unset,
          ``max_model_len`` is capped at ``max_trained_positions`` (default
          2048) and a warning explains that context extension is disabled.
        * Otherwise ``max_position_embeddings`` is reset to
          ``max_trained_positions`` and ``max_model_len`` is recomputed from
          the ``hf_overrides`` dict entry (or the current value), so that
          vLLM-style ``rope_theta``/``rope_scaling`` can extend the context.
        """
        model_config = vllm_config.model_config
        config = model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]

        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        # Nomic's "swiglu" maps to vLLM's "silu" activation name; "gelu" is
        # passed through unchanged.
        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        # Only a single ``config.bias`` flag is exposed, so the three Nomic
        # bias settings must agree.
        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        # Rotary variants that this translation does not support.
        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
        max_trained_positions = getattr(config, "max_trained_positions", 2048)
        # Prefer rope_theta and fall back to rotary_emb_base only when it is
        # missing. A getattr() default would read rotary_emb_base eagerly and
        # raise AttributeError for configs that define only rope_theta.
        if hasattr(config, "rope_theta"):
            rope_base = config.rope_theta
        else:
            rope_base = config.rotary_emb_base
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_emb_dim,
            "max_position": max_trained_positions,
            "base": rope_base,
            "rope_scaling": getattr(config, "rope_scaling", None),
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_scaling.
        # See #17785 #18755
        if (
            not model_config.hf_overrides
            and model_config.original_max_model_len is None
        ):
            # Default: reset max_model_len to max_trained_positions.
            # For nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(max_model_len_before, max_trained_positions)

            vllm_config.recalculate_max_model_len(max_model_len)
            logger.warning(
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                max_model_len_before,
                vllm_config.model_config.max_model_len,
            )
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", model_config.max_model_len
                )
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]

            # The priority of sentence_bert_config.json is higher than
            # max_position_embeddings, so drop its max_seq_length. Skip when
            # there is no encoder config at all: None has no ``pop``.
            if model_config.encoder_config is not None:
                encoder_config = deepcopy(model_config.encoder_config)
                encoder_config.pop("max_seq_length", None)
                model_config.encoder_config = encoder_config

            vllm_config.recalculate_max_model_len(max_model_len)


class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Fill in the default step-tag token id for the reward pooler.

        A ``step_tag_id`` the user already configured is kept as is.
        """
        pooler = vllm_config.model_config.pooler_config
        if pooler.step_tag_id is not None:
            return
        # NOTE(review): 151651 is a hard-coded token id, presumably a special
        # token of the Qwen2 tokenizer — confirm against the vocabulary.
        default_step_tag_id = 151651
        pooler.step_tag_id = default_step_tag_id


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Turn softmax off on the pooler unless the user configured it."""
        pooler = vllm_config.model_config.pooler_config
        if pooler.softmax is not None:
            return
        pooler.softmax = False


class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Configure the original Qwen3 reranker checkpoint.

        Runs only when the HF config sets ``is_original_qwen3_reranker``.
        Such configs must provide exactly two ``classifier_from_token``
        entries; ``hf_config.method`` is then set to ``"from_2_way_softmax"``.
        """
        config = vllm_config.model_config.hf_config

        if not getattr(config, "is_original_qwen3_reranker", False):
            return

        tokens = getattr(config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, (
            "Try loading the original Qwen3 Reranker?, see: "
            "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"
        )
        config.method = "from_2_way_softmax"


class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Configure Jina VL as a single-output ranking model.

        Forces ``num_labels = 1`` and, unless the user set one, applies a
        default pooler ``logit_bias`` of 2.65.
        """
        config = vllm_config.model_config.hf_config
        # One output label: a single score per input.
        config.num_labels = 1

        pooler_config = vllm_config.model_config.pooler_config
        # NOTE(review): 2.65 is a model-specific magic default; document its
        # origin.
        if pooler_config.logit_bias is None:
            pooler_config.logit_bias = 2.65


class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Adapt a Snowflake ``GteConfig`` for vLLM.

        Requires ``hidden_act == "gelu"``, rewrites it to ``"geglu"`` and
        stores the rotary-embedding arguments in ``config.rotary_kwargs``.
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "GteConfig", (
            f"expected GteConfig, got {config.__class__.__name__}"
        )
        assert config.hidden_act == "gelu", (
            f"expected hidden_act='gelu', got {config.hidden_act!r}"
        )

        # NOTE(review): the checkpoint's "gelu" is remapped to vLLM's "geglu";
        # presumably the MLP is gated — confirm against the model definition.
        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        config.rotary_kwargs = {
            "head_size": head_dim,
            # rotary_emb_dim, when present, limits RoPE to part of the head.
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rope_theta,
            "rope_scaling": getattr(config, "rope_scaling", None),
        }


class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Apply GPT-OSS defaults.

        * Select the ``openai_gptoss`` reasoning parser when none is set
          (empty string).
        * Raise the max CUDA graph capture size to 992, unless the user set
          ``cudagraph_capture_sizes`` or ``max_cudagraph_capture_size``.
        """
        structured_outputs_config = vllm_config.structured_outputs_config
        if structured_outputs_config.reasoning_parser == "":
            structured_outputs_config.reasoning_parser = "openai_gptoss"

        # Increase the max capture size from 512 to 992 for performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 81.
        # FIXME(woosuk): When using full cuda graph with FA3, the max
        # supported size is 992.
        max_capture_size = 992

        compilation_config = vllm_config.compilation_config
        # Only override when the user has not set either of
        # cudagraph_capture_sizes or max_cudagraph_capture_size.
        if (
            compilation_config.cudagraph_capture_sizes is None
            and compilation_config.max_cudagraph_capture_size is None
        ):
            compilation_config.max_cudagraph_capture_size = max_capture_size
            logger.info(
                "Overriding max cuda graph capture size to %d for performance.",
                max_capture_size,
            )


class MambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Apply cache defaults for mamba-based (and hybrid) models.

        * Default ``cache_config.mamba_block_size`` to ``max_model_len``.
        * Keep prefix caching only when the model supports it for mamba
          layers (logging that support is experimental); otherwise disable it.
        * Disable cascade attention, which hybrid models do not support.

        Args:
            vllm_config: vLLM Config
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.mamba_block_size is None:
            cache_config.mamba_block_size = model_config.max_model_len

        if cache_config.enable_prefix_caching:
            if model_config.supports_mamba_prefix_caching:
                logger.info(
                    "Warning: Prefix caching is currently enabled. "
                    "Its support for Mamba layers is experimental. "
                    "Please report any issues you may observe."
                )
            else:
                logger.info(
                    "Hybrid or mamba-based model detected without "
                    "support for prefix caching: disabling."
                )
                cache_config.enable_prefix_caching = False

        # TODO(tdoublep): remove once cascade attention is supported
        logger.info(
            "Disabling cascade attention since it is not supported for hybrid models."
        )
        model_config.disable_cascade_attn = True


class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure that page size of attention layers is greater than or
        equal to the mamba layers. If not, automatically set the attention
        block size to ensure that it is. If the attention page size is
        strictly greater than the mamba page size, we pad the mamba page size
        to make them equal.

        ``MambaModelConfig.verify_and_update_config`` is applied first. With
        prefix caching enabled, ``cache_config.mamba_block_size`` is also set
        to the chosen attention block size.

        Args:
            vllm_config: vLLM Config
        """
        # Save the user input before it gets modified by MambaModelConfig
        mamba_block_size = vllm_config.cache_config.mamba_block_size

        # Apply the shared mamba cache defaults (mamba_block_size, prefix
        # caching support check, cascade attention).
        MambaModelConfig.verify_and_update_config(vllm_config)

        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # get attention page size (for 1 token)
        # Attention backend constraints:
        # - FlashAttention (FA) requires block size to be multiple of 16
        # - MLA (Multi-head Latent Attention) requires larger alignment:
        #   * CUTLASS_MLA backend: kernel_block_size 128 alignment
        #   * Other MLA backends: kernel_block_size 64 alignment
        head_size = model_config.get_head_size()
        if model_config.use_mla:
            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
            attn_spec_cls = MLAAttentionSpec
        else:
            kernel_block_alignment_size = 16
            if (
                current_platform.is_device_capability(100)
                and head_size == 256
                and envs.VLLM_ATTENTION_BACKEND in (None, "FLASHINFER")
            ):
                # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports
                # that head size 256 and block size 16 is not supported on
                # blackwell.
                kernel_block_alignment_size = 32
            attn_spec_cls = FullAttentionSpec

        attn_page_size_1_token = attn_spec_cls(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=head_size,
            dtype=kv_cache_dtype,
        ).page_size_bytes

        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )

        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=model_config.max_model_len,
        ).page_size_bytes

        # Model may be marked as is_hybrid but mamba is skipped via config:
        # nothing to align, return directly.
        if mamba_page_size == 0:
            return

        if cache_config.enable_prefix_caching:
            # With prefix caching, select attention block size to
            # optimize for mamba kernel performance

            # Mamba2 SSD kernel uses a chunk_size, e.g. 256
            # Align the block to the kernel: use lowest multiple of chunk_size
            # of attention tokens that would fit mamba_page_size:
            # e.g. for mamba page size = 788kB
            #          attn_1_token = 2kB -> fits ~394 tokens
            #      then round up to a multiple of 256 -> 512 tokens
            # End result:
            #  attn_block_size = 512
            #  mamba_block_size = 512 (aligned to a multiple of chunk_size)
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.
            base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, select minimum valid attention block size
            # to minimize mamba state padding. It must satisfy both:
            # 1. Backend alignment requirements (kernel_block_alignment_size)
            # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
            attn_block_size = kernel_block_alignment_size * cdiv(
                mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
            )

        # override attention block size if either (a) the
        # user has not set it or (b) the user has set it
        # too small.
        if cache_config.block_size is None or cache_config.block_size < attn_block_size:
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size,
            )

        # compute new attention page size
        attn_page_size = cache_config.block_size * attn_page_size_1_token

        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # don't need to pad mamba page size
            return

        # pad mamba page size to exactly match attention
        # (a None padded size also compares unequal here).
        if cache_config.mamba_page_size_padded != attn_page_size:
            cache_config.mamba_page_size_padded = attn_page_size
            mamba_padding_pct = (
                100 * (attn_page_size - mamba_page_size) / mamba_page_size
            )
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.",
                mamba_padding_pct,
            )


class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Update the KV-cache dtype for DeepSeek-V3.2.

        * Any ``fp8*`` cache dtype is replaced by the custom ``"fp8_ds_mla"``
          format.
        * ``"bfloat16"`` is mapped to ``"auto"`` (assumed to resolve to the
          model's bf16 dtype — confirm).

        Raises:
            AssertionError: if the HF config lacks ``index_topk``.
        """
        hf_config = vllm_config.model_config.hf_config

        # Mirror the check in vllm/model_executor/models/deepseek_v2.py
        is_v32 = hasattr(hf_config, "index_topk")
        assert is_v32, "DeepSeek-V3.2 config is expected to define index_topk"

        # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
        cache_config = vllm_config.cache_config
        if cache_config.cache_dtype.startswith("fp8"):
            cache_config.cache_dtype = "fp8_ds_mla"
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
        elif cache_config.cache_dtype == "bfloat16":
            cache_config.cache_dtype = "auto"
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")


# Architecture name -> hook class whose ``verify_and_update_config`` adjusts
# the ``VllmConfig`` for that architecture.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "FalconMambaForCausalLM": MambaModelConfig,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
}