# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING

from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec

if TYPE_CHECKING:
    # Only needed for the string annotations below; not imported at runtime.
    from vllm.config import ModelConfig, VllmConfig

logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base class for per-architecture config fix-up hooks.

    Subclasses (registered in ``MODELS_CONFIG_MAP``) override whichever hook
    they need. Both default implementations are no-ops.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate and adjust the full ``VllmConfig`` in place (no-op here)."""
        return

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Validate and adjust the ``ModelConfig`` in place (no-op here)."""
        return
29
30


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    """Config fix-ups for DeepSeek-V3.2 checkpoints."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Check this is a V3.2 config and normalise a ``bfloat16`` KV cache.

        A requested ``cache_dtype`` of ``"bfloat16"`` is rewritten to ``"auto"``.
        Any other value is left as-is.

        Raises:
            AssertionError: if the HF config has no ``index_topk`` attribute.
        """
        # V3.2 is recognised by ``index_topk``; this mirrors the check in
        # vllm/model_executor/models/deepseek_v2.py.
        assert hasattr(vllm_config.model_config.hf_config, "index_topk")

        cache_config = vllm_config.cache_config
        if cache_config.cache_dtype != "bfloat16":
            return

        cache_config.cache_dtype = "auto"
        logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")


class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig):
    """Compilation tweaks for Ernie4.5-VL MoE models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Turn off ``fast_moe_cold_start`` unconditionally.

        Ernie4.5-VL executes its text and vision MoE branches conditionally,
        and the fast cold-start path can silently run them in the wrong order.
        """
        compilation_config = vllm_config.compilation_config
        compilation_config.fast_moe_cold_start = False


54
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Derive ``hf_config.is_causal`` from the bidirectional-attention flag.

        ``is_causal`` is set to the negation of
        ``hf_config.use_bidirectional_attention``. The HF config must define
        that attribute; otherwise ``AttributeError`` propagates.
        """
        hf_config = model_config.hf_config
        hf_config.is_causal = not hf_config.use_bidirectional_attention


61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Apply gpt-oss specific defaults.

        * If no structured-outputs reasoning parser was chosen (empty string),
          use ``"openai_gptoss"``.
        * Raise the max CUDA graph capture size from 512 to 1024 for
          performance, but only when the user set neither
          ``cudagraph_capture_sizes`` nor ``max_cudagraph_capture_size``.
        """
        outputs_cfg = vllm_config.structured_outputs_config
        if outputs_cfg.reasoning_parser == "":
            outputs_cfg.reasoning_parser = "openai_gptoss"

        # NOTE(woosuk): a cap of 1024 increases the number of CUDA graphs
        # from 67 to 83.
        capture_size = 1024
        comp_cfg = vllm_config.compilation_config
        user_set_capture = (
            comp_cfg.cudagraph_capture_sizes is not None
            or comp_cfg.max_cudagraph_capture_size is not None
        )
        if user_set_capture:
            return

        comp_cfg.max_cudagraph_capture_size = capture_size
        logger.info(
            "Overriding max cuda graph capture size to %d for performance.",
            capture_size,
        )


84
85
class GteNewModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Adapt an HF ``NewConfig`` (GTE "new" models) for vLLM.

        Makes three changes to the HF config:
        * switches ``hidden_act`` from ``"gelu"`` to ``"geglu"``;
        * records ``partial_rotary_factor`` in ``rope_parameters``;
        * stores the RoPE arguments (head size, max position, rope parameters)
          on ``config.rotary_kwargs``.

        Raises:
            AssertionError: if the HF config is not a ``NewConfig`` or its
                ``hidden_act`` is not ``"gelu"``.
        """
        config = model_config.hf_config

        assert config.__class__.__name__ == "NewConfig"
        assert config.hidden_act == "gelu"

        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        # Rotate only ``rotary_emb_dim`` channels per head when given,
        # otherwise the whole head.
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim

        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": config.max_position_embeddings,
            "rope_parameters": config.rope_parameters,
        }


104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure that page size of attention layers is greater than or
        equal to the mamba layers. If not, automatically set the attention
        block size to ensure that it is. If the attention page size is
        strictly greater than the mamba page size, we pad the mamba page size
        to make them equal.

        Returns early without changes when the model's mamba page size is 0
        (mamba layers skipped via config).

        Args:
            vllm_config: vLLM Config

        Raises:
            AssertionError: if the resulting attention page size is still
                smaller than the mamba page size.
        """
        # Save the user input before it gets modified by MambaModelConfig
        mamba_block_size = vllm_config.cache_config.mamba_block_size
        # Apply the shared mamba defaults (mamba cache mode, mamba block size)
        MambaModelConfig.verify_and_update_config(vllm_config)

        attention_config = vllm_config.attention_config
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # get attention page size (for 1 token)
        # Attention backend constraints:
        # - FlashAttention (FA) requires block size to be multiple of 16
        # - MLA (Multi-head Latent Attention) requires larger alignment:
        #   * CUTLASS_MLA backend: kernel_block_size 128 alignment
        #   * Other MLA backends: kernel_block_size 64 alignment
        if model_config.use_mla:
            use_cutlass_mla = (
                attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
            )
            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
            attn_page_size_1_token = MLAAttentionSpec(
                block_size=1,
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                dtype=kv_cache_dtype,
            ).page_size_bytes
        else:
            kernel_block_alignment_size = 16
            if (
                current_platform.is_device_capability_family(100)
                and model_config.get_head_size() == 256
                and (
                    attention_config.backend is None
                    or attention_config.backend == AttentionBackendEnum.FLASHINFER
                )
            ):
                # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
                # head size 256 and block size 16 is not supported on blackwell.
                kernel_block_alignment_size = 32
            attn_page_size_1_token = FullAttentionSpec(
                block_size=1,
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                dtype=kv_cache_dtype,
            ).page_size_bytes

        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )

        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=-1,  # block_size doesn't matter for mamba page size
        ).page_size_bytes

        # Model may be marked as is_hybrid
        #  but mamba is skipped via config,
        #  return directly
        if mamba_page_size == 0:
            return

        if cache_config.mamba_cache_mode == "all":
            # With prefix caching, select attention block size to
            # optimize for mamba kernel performance

            # Mamba2 SSD kernel uses a chunk_size, e.g. 256
            # Align the block to the kernel: use lowest multiple of chunk_size
            # of attention tokens that would fit mamba_page_size:
            # e.g. for mamba page size = 788kB
            #          attn_1_token = 2kB -> fits ~394 tokens
            #      then round up to a multiple of 256 -> 512 tokens
            # End result:
            #  attn_block_size = 512
            #  mamba_block_size = 512 (aligned to a multiple of chunk_size)
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.

            base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
            # The block must satisfy both the kernel chunking and the
            # attention backend alignment.
            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, select minimum valid attention block size
            # to minimize mamba state padding

            # Calculate minimum attention block size that satisfies both:
            # 1. Backend alignment requirements (kernel_block_alignment_size)
            # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
            attn_block_size = kernel_block_alignment_size * cdiv(
                mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
            )

        # override attention block size if it is too small,
        # even if the user has explicitly set it
        if cache_config.block_size < attn_block_size:
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size,
            )

        # By default, mamba block size will be set to max_model_len.
        # When enabling prefix caching and using align mamba cache
        # mode, we align mamba block size to the block size as the
        # basic granularity for prefix caching.
        if cache_config.mamba_cache_mode == "align":
            cache_config.mamba_block_size = cache_config.block_size

        # compute new attention page size
        attn_page_size = cache_config.block_size * attn_page_size_1_token

        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # don't need to pad mamba page size
            return

        # pad mamba page size to exactly match attention
        # (an unset value, None, also compares unequal here)
        if cache_config.mamba_page_size_padded != attn_page_size:
            cache_config.mamba_page_size_padded = attn_page_size
            mamba_padding_pct = (
                100 * (attn_page_size - mamba_page_size) / mamba_page_size
            )
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.",
                mamba_padding_pct,
            )


263
264
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.use_activation`` to ``False`` when unset.

        An explicit user choice (``True`` or ``False``) is left untouched.
        """
        pooler_config = model_config.pooler_config
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
269
270


271
272
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Prepare RoPE arguments for rotary Jina/XLM-RoBERTa configs.

        Configs whose ``position_embedding_type`` is not ``"rotary"`` are left
        unchanged. Rotary configs get ``partial_rotary_factor`` recorded in
        ``rope_parameters`` and ``config.rotary_kwargs`` populated.

        Raises:
            AssertionError: if a rotary config is not ``XLMRobertaFlashConfig``.
        """
        config = model_config.hf_config

        if config.position_embedding_type != "rotary":
            return

        assert config.__class__.__name__ == "XLMRobertaFlashConfig"

        head_dim = config.hidden_size // config.num_attention_heads

        max_position = config.max_position_embeddings
        # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
        # out-of-bound index issue at RoPE for long prompts with torch.compile,
        # because it can't be divided by triton num_warps(default=4 or 8).
        # To deal with this, we increase max_position to multiple of n_warps,
        # so that triton kernel won't hit out-of-bound index in RoPE cache.
        if not model_config.enforce_eager:
            max_position = round_up(max_position, 8)

        # Rotate only ``rotary_emb_dim`` channels per head when given,
        # otherwise the whole head.
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim

        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": max_position,
            "rope_parameters": config.rope_parameters,
        }


299
300
301
302
303
304
305
306
307
308
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Configure Jina-VL rankers.

        Forces a single output label. ``pooler_config.logit_bias`` defaults
        to 2.65, but only when it has not already been set.
        """
        model_config.hf_config.num_labels = 1

        pooler_config = model_config.pooler_config
        if pooler_config.logit_bias is not None:
            return
        pooler_config.logit_bias = 2.65


309
310
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Configure bidirectional Llama models for pooling.

        Marks the HF config as non-causal and maps its ``pooling`` field
        (``"avg"``, ``"cls"`` or ``"last"``) onto
        ``pooler_config.seq_pooling_type``.

        Raises:
            ValueError: if ``hf_config.pooling`` is not a supported value.
        """
        # Imported locally (presumably to avoid an import cycle at module load).
        from vllm.config.pooler import SequencePoolingType

        hf_config = model_config.hf_config
        hf_config.is_causal = False

        pooling_type_map: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        pooling_type = pooling_type_map.get(hf_config.pooling)
        if pooling_type is None:
            raise ValueError(f"pool_type {hf_config.pooling!r} not supported")

        model_config.pooler_config.seq_pooling_type = pooling_type
328
329


330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
    """Config handler for LlamaNemotronVL embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Make the model bidirectional and resolve its sequence pooling type.

        * ``is_causal`` is set to ``False`` on the HF config and, when present,
          on its nested ``llm_config``.
        * ``patch_size`` is copied up from ``vision_config`` when present.
        * ``pooling`` is read from the top-level config first. If it is missing
          there and ``llm_config`` exists, it is read from ``llm_config``
          (defaulting to ``"avg"``).

        Raises:
            ValueError: if the resolved pooling value is not supported
                (including ``None`` when neither config provides one).
        """
        from vllm.config.pooler import SequencePoolingType

        hf_config = model_config.hf_config
        has_llm_config = hasattr(hf_config, "llm_config")

        # Bidirectional attention on the wrapper and the language-model config.
        hf_config.is_causal = False
        if has_llm_config:
            hf_config.llm_config.is_causal = False

        if hasattr(hf_config, "vision_config"):
            hf_config.patch_size = hf_config.vision_config.patch_size

        pooling_by_name: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        pooling = getattr(hf_config, "pooling", None)
        if pooling is None and has_llm_config:
            pooling = getattr(hf_config.llm_config, "pooling", "avg")

        seq_pooling_type = pooling_by_name.get(pooling)
        if seq_pooling_type is None:
            raise ValueError(f"pool_type {pooling!r} not supported")

        model_config.pooler_config.seq_pooling_type = seq_pooling_type


366
367
368
369
370
371
class MambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Reconcile the mamba cache mode and mamba block size with the
        prefix-caching setting.

        With prefix caching enabled:
        * a ``"none"`` mamba cache mode defaults to ``"all"`` when the model
          supports mamba prefix caching, and to ``"align"`` otherwise;
        * ``"all"`` on a model without that support falls back to ``"align"``;
        * ``"align"`` requires chunked prefill;
        * an unset ``mamba_block_size`` defaults to ``block_size``.

        With prefix caching disabled, the mamba cache mode is forced to
        ``"none"`` and an unset ``mamba_block_size`` defaults to
        ``max_model_len``.

        Args:
            vllm_config: vLLM Config

        Raises:
            AssertionError: if mamba cache mode ``"align"`` is selected while
                chunked prefill is disabled.
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if cache_config.mamba_cache_mode == "none":
                cache_config.mamba_cache_mode = (
                    "all" if model_config.supports_mamba_prefix_caching else "align"
                )
                logger.warning(
                    "Mamba cache mode is set to '%s' for %s by default "
                    "when prefix caching is enabled",
                    cache_config.mamba_cache_mode,
                    model_config.architecture,
                )
            if (
                cache_config.mamba_cache_mode == "all"
                and not model_config.supports_mamba_prefix_caching
            ):
                cache_config.mamba_cache_mode = "align"
                logger.warning(
                    "Hybrid or mamba-based model detected without support "
                    "for prefix caching with Mamba cache 'all' mode: "
                    "falling back to 'align' mode."
                )
            if cache_config.mamba_cache_mode == "align":
                assert vllm_config.scheduler_config.enable_chunked_prefill, (
                    "Chunked prefill is required for mamba cache mode 'align'."
                )
            logger.info(
                "Warning: Prefix caching in Mamba cache '%s' "
                "mode is currently enabled. "
                "Its support for Mamba layers is experimental. "
                "Please report any issues you may observe.",
                cache_config.mamba_cache_mode,
            )
            # By default, mamba block size will be set to max_model_len (see
            # below). When enabling prefix caching, we align mamba block size
            # to the block size as the basic granularity for prefix caching.
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = cache_config.block_size
        else:
            if cache_config.mamba_cache_mode != "none":
                cache_config.mamba_cache_mode = "none"
                logger.warning(
                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
                )
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = model_config.max_model_len


class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Resolve an ``"auto"`` mamba SSM cache dtype for NemotronH models.

        When ``cache_config.mamba_ssm_cache_dtype`` is ``"auto"``, it is
        replaced by the HF config's ``mamba_ssm_cache_dtype`` attribute, or by
        ``"float16"`` if the HF config has none. Any other value is kept.
        """
        cache_config = vllm_config.cache_config
        if cache_config.mamba_ssm_cache_dtype != "auto":
            return

        resolved_dtype = getattr(
            vllm_config.model_config.hf_config, "mamba_ssm_cache_dtype", "float16"
        )
        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            resolved_dtype,
        )
        cache_config.mamba_ssm_cache_dtype = resolved_dtype


class NemotronHNanoVLV2Config(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the video backend to ``"nemotron_vl"``.

        Nothing happens when the model has no multimodal config. An existing
        ``media_io_kwargs["video"]["video_backend"]`` is preserved.
        """
        mm_config = model_config.multimodal_config
        if mm_config is None:
            return

        video_io_kwargs = mm_config.media_io_kwargs.setdefault("video", {})
        video_io_kwargs.setdefault("video_backend", "nemotron_vl")


class NomicBertModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Adapt an HF ``NomicBertConfig`` for vLLM and re-derive max length.

        The method does three things:
        * Copies Nomic-specific fields (``activation_function``, ``n_inner``,
          ``n_embd``, ``n_layer``, ``layer_norm_epsilon``, ``qkv_proj_bias``)
          onto the standard names used elsewhere.
        * Builds ``config.rotary_kwargs``.
        * Recomputes ``max_model_len``. Without overrides it is capped at
          ``max_trained_positions`` (default 2048). With ``hf_overrides`` or
          ``original_max_model_len`` it is re-verified against the trained
          position limit.

        Raises:
            AssertionError: if the config is not a ``NomicBertConfig`` or uses
                an unsupported activation / bias / rotary setting.
        """
        config = model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer
        # Keep model_arch_config in sync with the renamed HF fields.
        model_config.model_arch_config.hidden_size = config.hidden_size
        model_config.model_arch_config.total_num_hidden_layers = (
            config.num_hidden_layers
        )

        head_dim = config.hidden_size // config.num_attention_heads
        max_trained_positions = getattr(config, "max_trained_positions", 2048)

        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_parameters.
        # See #17785 #18755
        if (
            not model_config.hf_overrides
            and model_config.original_max_model_len is None
        ):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len, max_trained_positions)

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )

            if model_config.max_model_len != max_model_len_before:
                logger.warning(
                    "Nomic context extension is disabled. "
                    "Changing max_model_len from %s to %s. "
                    "To enable context extension, see: "
                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                    max_model_len_before,
                    model_config.max_model_len,
                )
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", model_config.max_model_len
                )
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

            # Update the cached derived_max_model_len to enforce the limit
            model_config.model_arch_config.derived_max_model_len_and_key = (
                float(max_trained_positions),
                "max_position_embeddings",
            )

            # The priority of sentence_bert_config.json is higher
            # than max_position_embeddings, so drop its max_seq_length.
            # Guarded because deepcopy(None).pop(...) would raise when the
            # model ships no sentence-transformers encoder config.
            if model_config.encoder_config is not None:
                encoder_config = deepcopy(model_config.encoder_config)
                encoder_config.pop("max_seq_length", None)
                model_config.encoder_config = encoder_config

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )
560
561


562
563
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.step_tag_id`` to 151651 when it is unset."""
        pooler_config = model_config.pooler_config

        # NOTE(review): 151651 is presumably the tokenizer id of the
        # process-reward step tag token; confirm against the tokenizer.
        if pooler_config.step_tag_id is None:
            pooler_config.step_tag_id = 151651


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.use_activation`` to ``False`` when unset.

        An explicit user choice (``True`` or ``False``) is left untouched.
        """
        pooler_config = model_config.pooler_config
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
578
579


580
581
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Turn an original Qwen3 reranker checkpoint into a 2-way classifier.

        This only applies when ``hf_config.is_original_qwen3_reranker`` is
        truthy. In that case ``classifier_from_token`` must contain exactly two
        tokens. The text config is then switched to the ``"from_2_way_softmax"``
        method using those tokens.

        Raises:
            AssertionError: if ``classifier_from_token`` is missing or does not
                hold exactly two tokens.
        """
        config = model_config.hf_config

        is_original_qwen3_reranker = getattr(
            config, "is_original_qwen3_reranker", False
        )
        if not is_original_qwen3_reranker:
            return

        tokens = getattr(config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, (
            "Try loading the original Qwen3 Reranker?, see: "
            "https://github.com/vllm-project/vllm/tree/main/examples/pooling/score/qwen3_reranker_offline.py"
        )

        text_config = config.get_text_config()
        text_config.method = "from_2_way_softmax"
        text_config.classifier_from_token = tokens
600
601


602
603
604
605
class Qwen3VLForSequenceClassificationConfig(Qwen3ForSequenceClassificationConfig):
    """Qwen3-VL sequence classifiers reuse the Qwen3 reranker handling as-is."""


606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Align the mamba SSM cache dtype with the Qwen3.5 HF config.

        If the HF text config defines ``mamba_ssm_dtype``:
        * an ``"auto"`` ``cache_config.mamba_ssm_cache_dtype`` is replaced by it;
        * a different explicit user value is kept, with a warning.
        Without ``mamba_ssm_dtype`` in the HF config nothing changes.
        """
        cache_config = vllm_config.cache_config
        model_ssm_dtype = getattr(
            vllm_config.model_config.hf_text_config, "mamba_ssm_dtype", None
        )
        if model_ssm_dtype is None:
            return

        if cache_config.mamba_ssm_cache_dtype == "auto":
            cache_config.mamba_ssm_cache_dtype = model_ssm_dtype
        elif cache_config.mamba_ssm_cache_dtype != model_ssm_dtype:
            logger.warning(
                "Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
                "but --mamba-ssm-cache-dtype='%s' was passed. "
                "Using the user-specified value.",
                model_ssm_dtype,
                cache_config.mamba_ssm_cache_dtype,
            )


633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Adapt an HF ``GteConfig`` (Snowflake GTE models) for vLLM.

        Switches ``hidden_act`` from ``"gelu"`` to ``"geglu"``, records
        ``partial_rotary_factor`` in ``rope_parameters``, and stores the RoPE
        arguments on ``rotary_kwargs``.

        Raises:
            AssertionError: if the HF config is not a ``GteConfig`` or its
                ``hidden_act`` is not ``"gelu"``.
        """
        hf_config = model_config.hf_config

        assert hf_config.__class__.__name__ == "GteConfig"
        assert hf_config.hidden_act == "gelu"
        hf_config.hidden_act = "geglu"

        head_size = hf_config.hidden_size // hf_config.num_attention_heads
        # Fall back to rotating the full head when rotary_emb_dim is absent.
        rotary_dim = getattr(hf_config, "rotary_emb_dim", head_size)
        hf_config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_size

        hf_config.rotary_kwargs = dict(
            head_size=head_size,
            max_position=hf_config.max_position_embeddings,
            rope_parameters=hf_config.rope_parameters,
        )


class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Make the Voyage Qwen3 embedder bidirectional and size its output.

        Sets ``is_causal`` to ``False`` and copies ``num_labels`` into
        ``embedding_size`` on the HF config.
        """
        hf_config = model_config.hf_config
        hf_config.is_causal = False
        hf_config.embedding_size = hf_config.num_labels


660
# Maps an HF architecture name to the config fix-up class applied to it.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "ColBERTJinaRobertaModel": JinaRobertaModelConfig,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
    "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig,  # noqa: E501
    "FalconMambaForCausalLM": MambaModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
    "LlamaNemotronVLModel": LlamaNemotronVLConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "MambaForCausalLM": MambaModelConfig,
    "NemotronHForCausalLM": NemotronHForCausalLMConfig,
    "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
    "NemotronH_Nano_VL_V2": NemotronHNanoVLV2Config,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
    "Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
}