config.py 28 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
4
from math import lcm
5
6
7
from typing import TYPE_CHECKING

from vllm.logger import init_logger
8
from vllm.model_executor.models import ModelRegistry
9
from vllm.platforms import current_platform
10
from vllm.utils.math_utils import cdiv, round_up
11
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
12
from vllm.v1.attention.backends.registry import AttentionBackendEnum
13
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
14
15

if TYPE_CHECKING:
16
    from vllm.config import ModelConfig, VllmConfig
17
18
19
20
21
22
23

logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base class for per-architecture configuration hooks.

    Both hooks are no-ops here. Subclasses in this module override the one
    they need and adjust the passed-in config object in place.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Hook operating on the full vLLM config; does nothing by default.

        Args:
            vllm_config: vLLM Config
        """
        return

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Hook operating on the model config only; does nothing by default.

        Args:
            model_config: model config to verify and update.
        """
        return
29
30


31
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for the ``Gemma3TextModel`` architecture."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Derive ``is_causal`` from the HF ``use_bidirectional_attention`` flag.

        Args:
            model_config: model config whose ``hf_config`` is updated in place.
        """
        hf_config = model_config.hf_config
        # Bidirectional attention is the opposite of causal attention.
        hf_config.is_causal = not hf_config.use_bidirectional_attention


38
39
class GteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook for HF configs whose class is named ``NewConfig``."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Rewrite the activation name and build the rotary-embedding kwargs.

        Args:
            model_config: model config whose ``hf_config`` is updated in place.

        Raises:
            AssertionError: if the HF config class is not ``NewConfig`` or its
                ``hidden_act`` is not ``"gelu"``.
        """
        config = model_config.hf_config

        assert config.__class__.__name__ == "NewConfig"
        assert config.hidden_act == "gelu"

        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        # Without an explicit ``rotary_emb_dim`` the whole head is rotated
        # (partial_rotary_factor == 1.0).
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": config.max_position_embeddings,
            "rope_parameters": config.rope_parameters,
        }


58
59
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook registered for ``JambaForSequenceClassification``."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the pooler's ``use_activation`` to ``False`` when unset.

        Args:
            model_config: model config whose ``pooler_config`` is updated.
        """
        pooler_config = model_config.pooler_config
        # Only fill in the default; an explicit value is left untouched.
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
64
65


66
67
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    """Config hook for XLM-RoBERTa style models that may use rotary embeddings."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Build the rotary-embedding kwargs when rotary positions are used.

        Nothing is changed unless ``position_embedding_type == "rotary"``.

        Args:
            model_config: model config whose ``hf_config`` is updated in place.

        Raises:
            AssertionError: if rotary embeddings are requested but the HF
                config class is not ``XLMRobertaFlashConfig``.
        """
        config = model_config.hf_config

        if config.position_embedding_type == "rotary":
            assert config.__class__.__name__ == "XLMRobertaFlashConfig"

            head_dim = config.hidden_size // config.num_attention_heads
            max_position = config.max_position_embeddings
            # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
            # out-of-bound index issue at RoPE for long prompts with torch.compile,
            # because it can't be divided by triton num_warps(default=4 or 8).
            # To deal with this, we increase max_position to multiple of n_warps,
            # so that triton kernel won't hit out-of-bound index in RoPE cache.
            if not model_config.enforce_eager:
                max_position = round_up(max_position, 8)

            # Without an explicit ``rotary_emb_dim`` the whole head is rotated.
            rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
            config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim

            config.rotary_kwargs = {
                "head_size": head_dim,
                "max_position": max_position,
                "rope_parameters": config.rope_parameters,
            }


94
95
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
    """Config hook for the bidirectional Llama architectures."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Disable causal attention and map the HF pooling name to a pooling type.

        Args:
            model_config: model config; ``hf_config`` and ``pooler_config`` are
                updated in place.

        Raises:
            ValueError: if ``hf_config.pooling`` is not one of ``"avg"``,
                ``"cls"`` or ``"last"``.
        """
        from vllm.config.pooler import SequencePoolingType

        hf_config = model_config.hf_config
        hf_config.is_causal = False

        # HF pooling name -> sequence pooling type.
        pooling_type_map: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        pooling_type = pooling_type_map.get(hf_config.pooling, None)
        if pooling_type is None:
            raise ValueError(f"pool_type {hf_config.pooling!r} not supported")

        model_config.pooler_config.seq_pooling_type = pooling_type
113
114


115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
    """Config hook for LlamaNemotronVL embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Turn off causal attention, mirror the patch size, pick the pooling type.

        Args:
            model_config: model config; ``hf_config`` and ``pooler_config`` are
                updated in place.

        Raises:
            ValueError: if the resolved pooling name is not ``"avg"``,
                ``"cls"`` or ``"last"``.
        """
        from vllm.config.pooler import SequencePoolingType

        hf_cfg = model_config.hf_config

        # Bidirectional attention on the top-level config and, when present,
        # on the nested language-model config.
        hf_cfg.is_causal = False
        has_llm_config = hasattr(hf_cfg, "llm_config")
        if has_llm_config:
            hf_cfg.llm_config.is_causal = False

        # Expose the vision patch size on the top-level config.
        if hasattr(hf_cfg, "vision_config"):
            hf_cfg.patch_size = hf_cfg.vision_config.patch_size

        # The top-level setting wins; otherwise fall back to the nested
        # language-model config, which defaults to "avg".
        pooling_name = getattr(hf_cfg, "pooling", None)
        if pooling_name is None and has_llm_config:
            pooling_name = getattr(hf_cfg.llm_config, "pooling", "avg")

        supported_pooling: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }
        if pooling_name not in supported_pooling:
            raise ValueError(f"pool_type {pooling_name!r} not supported")

        model_config.pooler_config.seq_pooling_type = supported_pooling[pooling_name]


151
152
class NomicBertModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for the ``NomicBertModel`` architecture."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Normalise the NomicBert HF config and settle ``max_model_len``.

        The Nomic-specific field names (``n_embd``, ``n_layer``, ...) are
        copied onto the generic attribute names, the rotary-embedding kwargs
        are built, and ``max_model_len`` is re-verified against
        ``max_trained_positions``.

        Args:
            model_config: model config; ``hf_config``, ``hf_text_config``,
                ``model_arch_config``, ``encoder_config`` and ``max_model_len``
                may all be updated in place.

        Raises:
            AssertionError: if the HF config is not a ``NomicBertConfig`` or
                uses an unsupported activation / bias / rotary setting.
        """
        config = model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        # A single ``bias`` flag is exposed, so all three flags must agree.
        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        # Copy the Nomic field names onto the generic attribute names.
        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer
        model_config.model_arch_config.hidden_size = config.hidden_size
        model_config.model_arch_config.total_num_hidden_layers = (
            config.num_hidden_layers
        )

        head_dim = config.hidden_size // config.num_attention_heads
        max_trained_positions = getattr(config, "max_trained_positions", 2048)

        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_parameters.
        # See #17785 #18755
        if (
            not model_config.hf_overrides
            and model_config.original_max_model_len is None
        ):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len, max_trained_positions)

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )

            if model_config.max_model_len != max_model_len_before:
                logger.warning(
                    "Nomic context extension is disabled. "
                    "Changing max_model_len from %s to %s. "
                    "To enable context extension, see: "
                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
                    max_model_len_before,
                    model_config.max_model_len,
                )
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", model_config.max_model_len
                )
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

            # Update the cached derived_max_model_len to enforce the limit
            model_config.model_arch_config.derived_max_model_len_and_key = (
                float(max_trained_positions),
                "max_position_embeddings",
            )

            # The priority of sentence_bert_config.json is higher
            # than max_position_embeddings
            encoder_config = deepcopy(model_config.encoder_config)
            # NOTE(review): guard against a missing (None) encoder config so
            # that ``.pop`` cannot raise AttributeError; confirm whether
            # ModelConfig can actually leave encoder_config unset here.
            if encoder_config is not None:
                encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )
256
257


258
259
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for ``Qwen2ForProcessRewardModel``."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the pooler's ``step_tag_id`` to 151651 when it is unset.

        Args:
            model_config: model config whose ``pooler_config`` is updated.
        """
        pooler_config = model_config.pooler_config

        # Only fill in the default; an explicit value is left untouched.
        if pooler_config.step_tag_id is None:
            pooler_config.step_tag_id = 151651


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for ``Qwen2ForRewardModel``."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the pooler's ``use_activation`` to ``False`` when unset.

        Args:
            model_config: model config whose ``pooler_config`` is updated.
        """
        pooler_config = model_config.pooler_config

        # Only fill in the default; an explicit value is left untouched.
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
274
275


276
277
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook registered for ``Qwen3ForSequenceClassification``."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Configure the two-token classifier for the original Qwen3 reranker.

        Nothing is changed unless the HF config carries a truthy
        ``is_original_qwen3_reranker`` attribute.

        Args:
            model_config: model config whose HF text config is updated in place.

        Raises:
            AssertionError: if ``classifier_from_token`` is missing or does not
                hold exactly two tokens.
        """
        config = model_config.hf_config

        is_original_qwen3_reranker = getattr(
            config, "is_original_qwen3_reranker", False
        )

        if not is_original_qwen3_reranker:
            return

        tokens = getattr(config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, (
            "Try loading the original Qwen3 Reranker?, see: "
            "https://github.com/vllm-project/vllm/tree/main/examples/pooling/score/qwen3_reranker_offline.py"
        )
        text_config = config.get_text_config()
        text_config.method = "from_2_way_softmax"
        text_config.classifier_from_token = tokens
296
297


298
299
300
301
class Qwen3VLForSequenceClassificationConfig(Qwen3ForSequenceClassificationConfig):
    """Qwen3-VL variant; inherits the Qwen3 hook without any changes."""


302
303
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook registered for the ``JinaVLForRanking`` architecture."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Force a single label and default the pooler's ``logit_bias``.

        Args:
            model_config: model config; ``hf_config`` and ``pooler_config`` are
                updated in place.
        """
        config = model_config.hf_config
        config.num_labels = 1
        pooler_config = model_config.pooler_config
        # Only fill in the default; an explicit value is left untouched.
        if pooler_config.logit_bias is None:
            pooler_config.logit_bias = 2.65
310
311


312
313
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook for HF configs whose class is named ``GteConfig``."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Rewrite the activation name and build the rotary-embedding kwargs.

        Args:
            model_config: model config whose ``hf_config`` is updated in place.

        Raises:
            AssertionError: if the HF config class is not ``GteConfig`` or its
                ``hidden_act`` is not ``"gelu"``.
        """
        config = model_config.hf_config

        assert config.__class__.__name__ == "GteConfig"
        assert config.hidden_act == "gelu"

        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        # Without an explicit ``rotary_emb_dim`` the whole head is rotated
        # (partial_rotary_factor == 1.0).
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": config.max_position_embeddings,
            "rope_parameters": config.rope_parameters,
        }


332
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    """Config hook registered for the ``GptOssForCausalLM`` architecture."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Apply the GPT-OSS defaults for reasoning parser and CUDA graphs.

        Args:
            vllm_config: vLLM Config; ``structured_outputs_config`` and
                ``compilation_config`` may be updated in place.
        """
        structured_outputs_config = vllm_config.structured_outputs_config
        # An empty string means no parser was chosen, so pick the default.
        if structured_outputs_config.reasoning_parser == "":
            structured_outputs_config.reasoning_parser = "openai_gptoss"

        # Increase the max capture size from 512 to 1024 for performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 83.
        compilation_config = vllm_config.compilation_config
        # Only override when the user has not set either of
        # cudagraph_capture_sizes or max_cudagraph_capture_size.
        if (
            compilation_config.cudagraph_capture_sizes is None
            and compilation_config.max_cudagraph_capture_size is None
        ):
            compilation_config.max_cudagraph_capture_size = 1024
            logger.info(
                "Overriding max cuda graph capture size to %d for performance.", 1024
            )
353
354


355
356
357
358
359
360
361
362
363
364
365
class MambaModelConfig(VerifyAndUpdateConfig):
    """Config hook for the pure Mamba architectures registered in this module."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Reconcile the mamba cache mode and mamba block size with the
        prefix-caching setting.

        With prefix caching enabled, a cache mode of "none" is replaced by
        "all" (when the model supports mamba prefix caching) or "align",
        an unsupported "all" falls back to "align", and an unset mamba
        block size defaults to the attention block size. With prefix
        caching disabled, the cache mode is forced to "none" and an unset
        mamba block size defaults to max_model_len.

        Args:
            vllm_config: vLLM Config

        Raises:
            AssertionError: if the cache mode ends up as "align" while
                chunked prefill is disabled.
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if cache_config.mamba_cache_mode == "none":
                cache_config.mamba_cache_mode = (
                    "all" if model_config.supports_mamba_prefix_caching else "align"
                )
                logger.warning(
                    "Mamba cache mode is set to '%s' for %s by default "
                    "when prefix caching is enabled",
                    cache_config.mamba_cache_mode,
                    model_config.architecture,
                )
            if (
                cache_config.mamba_cache_mode == "all"
                and not model_config.supports_mamba_prefix_caching
            ):
                cache_config.mamba_cache_mode = "align"
                logger.warning(
                    "Hybrid or mamba-based model detected without support "
                    "for prefix caching with Mamba cache 'all' mode: "
                    "falling back to 'align' mode."
                )
            if cache_config.mamba_cache_mode == "align":
                assert vllm_config.scheduler_config.enable_chunked_prefill, (
                    "Chunked prefill is required for mamba cache mode 'align'."
                )
            logger.info(
                "Warning: Prefix caching in Mamba cache '%s' "
                "mode is currently enabled. "
                "Its support for Mamba layers is experimental. "
                "Please report any issues you may observe.",
                cache_config.mamba_cache_mode,
            )
            # By default, mamba block size will be set to max_model_len (see
            # below). When enabling prefix caching, we align mamba block size
            # to the block size as the basic granularity for prefix caching.
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = cache_config.block_size
        else:
            if cache_config.mamba_cache_mode != "none":
                cache_config.mamba_cache_mode = "none"
                logger.warning(
                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
                )
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = model_config.max_model_len
413

414

415
416
417
418
419
420
421
422
423
424
425
426
427
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    """Config hook for hybrid models mixing attention and mamba layers."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure that page size of attention layers is greater than or
        equal to the mamba layers. If not, automatically set the attention
        block size to ensure that it is. If the attention page size is
        strictly greater than the mamba page size, we pad the mamba page size
        to make them equal.

        Args:
            vllm_config: vLLM Config

        Raises:
            AssertionError: if the final attention page size is still
                smaller than the mamba page size.
        """
        # Save the user input before it gets modified by MambaModelConfig
        mamba_block_size = vllm_config.cache_config.mamba_block_size
        # Apply the generic mamba cache-mode / block-size defaults first
        MambaModelConfig.verify_and_update_config(vllm_config)

        attention_config = vllm_config.attention_config
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # get attention page size (for 1 token)
        # Attention backend constraints:
        # - FlashAttention (FA) requires block size to be multiple of 16
        # - MLA (Multi-head Latent Attention) requires larger alignment:
        #   * CUTLASS_MLA backend: kernel_block_size 128 alignment
        #   * Other MLA backends: kernel_block_size 64 alignment
        if model_config.use_mla:
            use_cutlass_mla = (
                attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
            )
            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
            attn_page_size_1_token = MLAAttentionSpec(
                block_size=1,
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                dtype=kv_cache_dtype,
            ).page_size_bytes
        else:
            kernel_block_alignment_size = 16
            if (
                current_platform.is_device_capability_family(100)
                and model_config.get_head_size() == 256
                and (
                    attention_config.backend is None
                    or attention_config.backend == AttentionBackendEnum.FLASHINFER
                )
            ):
                # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
                # head size 256 and block size 16 is not supported on blackwell.
                kernel_block_alignment_size = 32
            attn_page_size_1_token = FullAttentionSpec(
                block_size=1,
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                dtype=kv_cache_dtype,
            ).page_size_bytes

        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )

        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=-1,  # block_size doesn't matter for mamba page size
        ).page_size_bytes

        # Model may be marked as is_hybrid
        #  but mamba is skipped via config,
        #  return directly
        if mamba_page_size == 0:
            return

        if cache_config.mamba_cache_mode == "all":
            # With prefix caching, select attention block size to
            # optimize for mamba kernel performance

            # Mamba2 SSD kernel uses a chunk_size, e.g. 256
            # Align the block to the kernel: use lowest multiple of chunk_size
            # of attention tokens that would fit mamba_page_size:
            # e.g. for mamba page size = 788kB
            #          attn_1_token = 2kB -> fits ~394 tokens
            #      then round up to a multiple of 256 -> 512 tokens
            # End result:
            #  attn_block_size = 512
            #  mamba_block_size = 512 (aligned to a multiple of chunk_size)
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.

            # A user-supplied mamba block size (captured before the defaults
            # were applied) takes precedence over the model's chunk size.
            base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
            # The block must be a multiple of both the chunk size and the
            # attention kernel alignment.
            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, select minimum valid attention block size
            # to minimize mamba state padding

            # Calculate minimum attention block size that satisfies both:
            # 1. Backend alignment requirements (kernel_block_alignment_size)
            # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
            attn_block_size = kernel_block_alignment_size * cdiv(
                mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
            )

        # override attention block size if either (a) the
        # user has not set it or (b) the user has set it
        # too small.
        if cache_config.block_size is None or cache_config.block_size < attn_block_size:
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size,
            )

        # By default, mamba block size will be set to max_model_len.
        # When enabling prefix caching and using align mamba cache
        # mode, we align mamba block size to the block size as the
        # basic granularity for prefix caching.
        if cache_config.mamba_cache_mode == "align":
            cache_config.mamba_block_size = cache_config.block_size

        # compute new attention page size
        attn_page_size = cache_config.block_size * attn_page_size_1_token

        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # don't need to pad mamba page size
            return

        # pad mamba page size to exactly match attention
        if (
            cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size
        ):
            cache_config.mamba_page_size_padded = attn_page_size
            mamba_padding_pct = (
                100 * (attn_page_size - mamba_page_size) / mamba_page_size
            )
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.",
                mamba_padding_pct,
            )
573
574


575
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    """Config hook registered for the ``DeepseekV32ForCausalLM`` architecture."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32

        Any kv-cache dtype starting with "fp8" is rewritten to "fp8_ds_mla",
        and "bfloat16" is rewritten to "auto". Other values are left as-is.

        Args:
            vllm_config: vLLM Config

        Raises:
            AssertionError: if the HF config has no ``index_topk`` attribute.
        """
        hf_config = vllm_config.model_config.hf_config

        # Mirror the check in vllm/model_executor/models/deepseek_v2.py
        is_v32 = hasattr(hf_config, "index_topk")
        assert is_v32

        # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
        cache_config = vllm_config.cache_config
        if cache_config.cache_dtype.startswith("fp8"):
            cache_config.cache_dtype = "fp8_ds_mla"
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
        if cache_config.cache_dtype == "bfloat16":
            cache_config.cache_dtype = "auto"
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
595
596


597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    """Config hook for the NemotronH architectures."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Resolve an ``"auto"`` mamba SSM cache dtype for NemotronH models.

        When ``mamba_ssm_cache_dtype`` is ``"auto"`` it is replaced by the
        value from the HF config, or by ``"float16"`` if the HF config does
        not specify one. Any other value is left untouched.

        Args:
            vllm_config: vLLM Config
        """
        cache_cfg = vllm_config.cache_config
        if cache_cfg.mamba_ssm_cache_dtype != "auto":
            # An explicit dtype was chosen; nothing to resolve.
            return

        resolved_dtype = getattr(
            vllm_config.model_config.hf_config, "mamba_ssm_cache_dtype", "float16"
        )
        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            resolved_dtype,
        )
        cache_cfg.mamba_ssm_cache_dtype = resolved_dtype


617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
    """Config hook for the Qwen3.5 conditional-generation architectures."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Reconcile ``mamba_ssm_cache_dtype`` with the HF ``mamba_ssm_dtype``.

        If the HF text config has no ``mamba_ssm_dtype`` nothing happens.
        Otherwise an ``"auto"`` cache dtype adopts the HF value, and a
        differing explicit cache dtype is kept but logged as a warning.

        Args:
            vllm_config: vLLM Config
        """
        cache_cfg = vllm_config.cache_config
        text_cfg = vllm_config.model_config.hf_text_config
        model_dtype = getattr(text_cfg, "mamba_ssm_dtype", None)
        requested_dtype = cache_cfg.mamba_ssm_cache_dtype

        if model_dtype is None:
            # The model expresses no preference; keep the current setting.
            return

        if requested_dtype == "auto":
            cache_cfg.mamba_ssm_cache_dtype = model_dtype
            return

        if requested_dtype != model_dtype:
            logger.warning(
                "Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
                "but --mamba-ssm-cache-dtype='%s' was passed. "
                "Using the user-specified value.",
                model_dtype,
                requested_dtype,
            )


chengchengpei's avatar
chengchengpei committed
644
645
646
647
648
649
650
class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for ``VoyageQwen3BidirectionalEmbedModel``."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Disable causal attention and copy ``num_labels`` to ``embedding_size``.

        Args:
            model_config: model config whose ``hf_config`` is updated in place.
        """
        hf_cfg = model_config.hf_config
        hf_cfg.is_causal = False
        hf_cfg.embedding_size = hf_cfg.num_labels


651
652
653
# Architecture name -> config hook class defined in this module.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    "LlamaNemotronVLModel": LlamaNemotronVLConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
    "ColBERTJinaRobertaModel": JinaRobertaModelConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "FalconMambaForCausalLM": MambaModelConfig,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
    "NemotronHForCausalLM": NemotronHForCausalLMConfig,
    "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
    "Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
}