# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import TYPE_CHECKING

from vllm.logger import init_logger
7
from vllm.utils.math_utils import round_up
8
9

if TYPE_CHECKING:
10
11
12
    from transformers import PretrainedConfig

    from vllm.config import CacheConfig, ModelConfig, VllmConfig
13

14

15
16
17
18
19
20
logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base class for per-architecture config hooks.

    Both hooks are no-ops here; subclasses override whichever one they need.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Hook receiving the full ``VllmConfig``; does nothing by default."""
        return None

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Hook receiving only the ``ModelConfig``; does nothing by default."""
        return None
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    """Config hook for DeepSeek V3.2."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Rewrite an explicit ``bfloat16`` KV-cache dtype to ``auto``.

        Args:
            vllm_config: vLLM Config
        """
        # Mirror the check in vllm/model_executor/models/deepseek_v2.py:
        # the HF config is treated as V3.2 when it carries ``index_topk``.
        assert hasattr(vllm_config.model_config.hf_config, "index_topk")

        kv_cache_cfg = vllm_config.cache_config
        if kv_cache_cfg.cache_dtype != "bfloat16":
            return

        kv_cache_cfg.cache_dtype = "auto"
        logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")


class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig):
    """Config hook for Ernie4.5-VL MoE."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Turn off ``fast_moe_cold_start`` in the compilation config."""
        compilation_cfg = vllm_config.compilation_config
        # Ernie4.5-VL conditionally executes text/vision MoE branches, so
        # fast_moe_cold_start can silently produce incorrect execution order.
        compilation_cfg.fast_moe_cold_start = False


51
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    """Config hook for Gemma3 text models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Set ``is_causal`` to the negation of ``use_bidirectional_attention``."""
        text_cfg = model_config.hf_config
        bidirectional = text_cfg.use_bidirectional_attention
        text_cfg.is_causal = not bidirectional


58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
class Gemma4Config(VerifyAndUpdateConfig):
    """Config hook for Gemma4 models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Pin one attention backend when head dimensions are heterogeneous.

        Some Gemma4 variants use a different head dimension for sliding
        window layers (``head_dim``) than for full attention layers
        (``global_head_dim``). Once ``global_head_dim`` exceeds 256,
        FlashAttention rejects those layers (its kernels require
        head_size <= 256) and vLLM ends up picking a different backend per
        layer type. Running with mixed backends causes numerical divergence
        and corrupted output.

        To avoid that, this hook selects TRITON_ATTN (no head_size ceiling)
        for every layer, but only if the user has not chosen a backend.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.
        """
        text_cfg = vllm_config.model_config.hf_text_config
        local_dim = getattr(text_cfg, "head_dim", None)
        global_dim = getattr(text_cfg, "global_head_dim", None)

        # Nothing to do unless both dimensions are present and differ.
        if local_dim is None or global_dim is None:
            return
        if local_dim == global_dim:
            return

        # Smaller models may carry global_head_dim while every layer still
        # fits FlashAttention's limit; leave backend selection alone then.
        if max(local_dim, global_dim) <= 256:
            return

        # Respect an explicit user choice.
        attn_cfg = vllm_config.attention_config
        if attn_cfg.backend is not None:
            return

        from vllm.v1.attention.backends.registry import (
            AttentionBackendEnum,
        )

        attn_cfg.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
            "backend to prevent mixed-backend numerical divergence.",
            local_dim,
            global_dim,
        )


110
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    """Config hook for GPT-OSS models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Rename the ``mxfp4`` quant method to ``gpt_oss_mxfp4``.

        The rename is applied, in order, to the ``quantization_config`` of
        ``hf_config`` and of ``hf_text_config`` when one is present.
        """
        for attr_name in ("hf_config", "hf_text_config"):
            quant_cfg = getattr(
                getattr(model_config, attr_name), "quantization_config", None
            )
            if quant_cfg is None or quant_cfg.get("quant_method") != "mxfp4":
                continue
            owner = getattr(model_config, attr_name)
            owner.quantization_config["quant_method"] = "gpt_oss_mxfp4"

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Apply GPT-OSS defaults for the reasoning parser and CUDA graphs.

        Args:
            vllm_config: vLLM Config
        """
        so_cfg = vllm_config.structured_outputs_config
        if so_cfg.reasoning_parser == "":
            so_cfg.reasoning_parser = "openai_gptoss"

        # Increase the max capture size from 512 to 1024 for performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 83.
        comp_cfg = vllm_config.compilation_config

        # Leave the setting untouched as soon as the user has set either
        # cudagraph_capture_sizes or max_cudagraph_capture_size.
        if comp_cfg.cudagraph_capture_sizes is not None:
            return
        if comp_cfg.max_cudagraph_capture_size is not None:
            return

        comp_cfg.max_cudagraph_capture_size = 1024
        logger.info(
            "Overriding max cuda graph capture size to %d for performance.", 1024
        )


150
151
class GteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook for GTE "New" models (HF config class ``NewConfig``)."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Switch the activation to ``geglu`` and populate ``rotary_kwargs``."""
        hf_cfg = model_config.hf_config

        assert hf_cfg.__class__.__name__ == "NewConfig"
        assert hf_cfg.hidden_act == "gelu"
        hf_cfg.hidden_act = "geglu"

        per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        # Without an explicit rotary_emb_dim the whole head is rotated
        # (partial_rotary_factor == 1.0).
        rope_params = hf_cfg.rope_parameters
        rope_params["partial_rotary_factor"] = (
            getattr(hf_cfg, "rotary_emb_dim", per_head) / per_head
        )

        hf_cfg.rotary_kwargs = dict(
            head_size=per_head,
            max_position=hf_cfg.max_position_embeddings,
            rope_parameters=hf_cfg.rope_parameters,
        )


170
171
172
173
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    """Config hook for hybrid attention/mamba models."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Run early validation and setup for hybrid attention/mamba models.

        Aligning the block size with mamba page sizes is not done here; it
        is handled later by Platform.update_block_size_for_backend(), which
        runs after model layers are constructed and the attention backend
        is known.

        Args:
            vllm_config: vLLM Config
        """
        kv_cfg = vllm_config.cache_config

        # calculate_kv_scales is switched off for hybrid models because
        # uninitialized recurrent state corrupts the scales during the
        # calibration pass.
        # See issue: https://github.com/vllm-project/vllm/issues/37554
        if kv_cfg.calculate_kv_scales:
            model_name = vllm_config.model_config.model
            logger.warning(
                "Disabling calculate_kv_scales for hybrid model '%s'. "
                "Hybrid models with recurrent layers (GDN, Mamba, SSM) "
                "produce unreliable KV cache scales during the "
                "calibration pass because recurrent state is "
                "uninitialized. Using default scale of 1.0 instead.",
                model_name,
            )
            kv_cfg.calculate_kv_scales = False

        # Hand over to the shared Mamba setup for the remaining defaults.
        MambaModelConfig.verify_and_update_config(vllm_config)


204
205
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for Jamba sequence classification."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the pooler's ``use_activation`` to ``False`` when unset."""
        pooler_cfg = model_config.pooler_config
        if pooler_cfg.use_activation is not None:
            return
        pooler_cfg.use_activation = False
210
211


212
213
214
215
216
217
class JinaForRankingConfig(VerifyAndUpdateConfig):
    """Config hook for Jina ranking models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Set ``embedding_size`` on the HF config to 512."""
        hf_cfg = model_config.hf_config
        hf_cfg.embedding_size = 512


218
219
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    """Config hook for Jina XLM-RoBERTa style models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Populate ``rotary_kwargs`` when the model uses rotary embeddings.

        Configs whose ``position_embedding_type`` is not ``"rotary"`` are
        left untouched.
        """
        hf_cfg = model_config.hf_config
        if hf_cfg.position_embedding_type != "rotary":
            return

        assert hf_cfg.__class__.__name__ == "XLMRobertaFlashConfig"

        per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads

        # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
        # out-of-bound index issue at RoPE for long prompts with torch.compile,
        # because it can't be divided by triton num_warps(default=4 or 8).
        # To deal with this, we increase max_position to multiple of n_warps,
        # so that triton kernel won't hit out-of-bound index in RoPE cache.
        position_limit = hf_cfg.max_position_embeddings
        if not model_config.enforce_eager:
            position_limit = round_up(position_limit, 8)

        rope_params = hf_cfg.rope_parameters
        rope_params["partial_rotary_factor"] = (
            getattr(hf_cfg, "rotary_emb_dim", per_head) / per_head
        )

        hf_cfg.rotary_kwargs = dict(
            head_size=per_head,
            max_position=position_limit,
            rope_parameters=hf_cfg.rope_parameters,
        )


246
247
248
249
250
251
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for Jina VL ranking models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Force a single label and default the pooler's ``logit_mean``."""
        model_config.hf_config.num_labels = 1

        pooler_cfg = model_config.pooler_config
        if pooler_cfg.logit_mean is not None:
            return
        pooler_cfg.logit_mean = 2.65
254
255


256
257
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
    """Config hook for bidirectional Llama models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Disable causal masking and translate the HF pooling name.

        Raises:
            ValueError: if ``hf_config.pooling`` is not one of
                ``"avg"``, ``"cls"`` or ``"last"``.
        """
        from vllm.config.pooler import SequencePoolingType

        hf_cfg = model_config.hf_config
        hf_cfg.is_causal = False

        # HF pooling name -> vLLM sequence pooling type.
        name_to_pooling: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        requested = hf_cfg.pooling
        resolved = name_to_pooling.get(requested)
        if resolved is None:
            raise ValueError(f"pool_type {requested!r} not supported")

        model_config.pooler_config.seq_pooling_type = resolved
275
276


277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
    """Config handler for LlamaNemotronVL embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Disable causal masking, copy the patch size and resolve pooling.

        Raises:
            ValueError: if the resolved pooling name is not one of
                ``"avg"``, ``"cls"`` or ``"last"``.
        """
        from vllm.config.pooler import SequencePoolingType

        hf_cfg = model_config.hf_config

        # Bidirectional attention, on the top-level config and on the
        # nested language model config when there is one.
        hf_cfg.is_causal = False
        has_llm_cfg = hasattr(hf_cfg, "llm_config")
        if has_llm_cfg:
            hf_cfg.llm_config.is_causal = False

        if hasattr(hf_cfg, "vision_config"):
            hf_cfg.patch_size = hf_cfg.vision_config.patch_size

        # HF pooling name -> vLLM sequence pooling type.
        name_to_pooling: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        # The top-level setting wins; otherwise fall back to llm_config,
        # which defaults to "avg" when it has no pooling attribute.
        requested = getattr(hf_cfg, "pooling", None)
        if requested is None and has_llm_cfg:
            requested = getattr(hf_cfg.llm_config, "pooling", "avg")

        resolved = name_to_pooling.get(requested)
        if resolved is None:
            raise ValueError(f"pool_type {requested!r} not supported")

        model_config.pooler_config.seq_pooling_type = resolved


313
314
315
316
317
318
class MambaModelConfig(VerifyAndUpdateConfig):
    """Config hook for mamba-based models."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Settle ``mamba_cache_mode`` and ``mamba_block_size`` on the cache
        config, depending on whether prefix caching is enabled.

        Args:
            vllm_config: vLLM Config
        """
        model_cfg = vllm_config.model_config
        kv_cfg = vllm_config.cache_config

        if not kv_cfg.enable_prefix_caching:
            # No prefix caching: the cache mode must be 'none' and the
            # mamba block size defaults to max_model_len.
            if kv_cfg.mamba_cache_mode != "none":
                kv_cfg.mamba_cache_mode = "none"
                logger.warning(
                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
                )
            if kv_cfg.mamba_block_size is None:
                kv_cfg.mamba_block_size = model_cfg.max_model_len
            return

        # Pick a default mode when the user left it at 'none'.
        if kv_cfg.mamba_cache_mode == "none":
            if model_cfg.supports_mamba_prefix_caching:
                kv_cfg.mamba_cache_mode = "all"
            else:
                kv_cfg.mamba_cache_mode = "align"
            logger.warning(
                "Mamba cache mode is set to '%s' for %s by default "
                "when prefix caching is enabled",
                kv_cfg.mamba_cache_mode,
                model_cfg.architecture,
            )

        # Downgrade 'all' when the model does not support it.
        if (
            kv_cfg.mamba_cache_mode == "all"
            and not model_cfg.supports_mamba_prefix_caching
        ):
            kv_cfg.mamba_cache_mode = "align"
            logger.warning(
                "Hybrid or mamba-based model detected without support "
                "for prefix caching with Mamba cache 'all' mode: "
                "falling back to 'align' mode."
            )

        if kv_cfg.mamba_cache_mode == "align":
            assert vllm_config.scheduler_config.enable_chunked_prefill, (
                "Chunked prefill is required for mamba cache mode 'align'."
            )

        logger.info(
            "Warning: Prefix caching in Mamba cache '%s' "
            "mode is currently enabled. "
            "Its support for Mamba layers is experimental. "
            "Please report any issues you may observe.",
            kv_cfg.mamba_cache_mode,
        )

        # With prefix caching the mamba block size follows the regular
        # block size, which is the basic granularity for prefix caching.
        if kv_cfg.mamba_block_size is None:
            kv_cfg.mamba_block_size = kv_cfg.block_size


class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    DEFAULT_MAMBA_SSM_CACHE_DTYPE = "float32"
    """Only `float32` is known to have no accuracy issues by default."""

    @classmethod
    def update_mamba_ssm_cache_dtype(
        cls, *, cache_config: "CacheConfig", hf_config: "PretrainedConfig"
    ) -> None:
        """Resolve an ``'auto'`` mamba_ssm_cache_dtype for NemotronH models.

        When ``cache_config.mamba_ssm_cache_dtype`` is ``'auto'`` it is
        replaced by ``hf_config.mamba_ssm_cache_dtype`` if the HF config
        defines it, otherwise by ``DEFAULT_MAMBA_SSM_CACHE_DTYPE``. Any
        other value is left as is.
        """
        if cache_config.mamba_ssm_cache_dtype != "auto":
            return

        resolved_dtype = getattr(
            hf_config, "mamba_ssm_cache_dtype", cls.DEFAULT_MAMBA_SSM_CACHE_DTYPE
        )
        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            resolved_dtype,
        )
        cache_config.mamba_ssm_cache_dtype = resolved_dtype

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Resolve the SSM cache dtype from the top-level HF config."""
        hf_cfg = vllm_config.model_config.hf_config
        cls.update_mamba_ssm_cache_dtype(
            cache_config=vllm_config.cache_config, hf_config=hf_cfg
        )

402
403

class NemotronHNanoVLV2Config(VerifyAndUpdateConfig):
    """Config hook for NemotronH Nano VL V2."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Resolve the SSM cache dtype from the nested ``text_config``."""
        text_cfg = vllm_config.model_config.hf_config.text_config
        NemotronHForCausalLMConfig.update_mamba_ssm_cache_dtype(
            cache_config=vllm_config.cache_config, hf_config=text_cfg
        )

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the video backend to ``nemotron_vl``.

        Does nothing without a multimodal config; a ``video_backend``
        already present in the video IO kwargs is kept.
        """
        mm_cfg = model_config.multimodal_config
        if mm_cfg is None:
            return
        mm_cfg.media_io_kwargs.setdefault("video", {}).setdefault(
            "video_backend", "nemotron_vl"
        )


class NomicBertModelConfig(VerifyAndUpdateConfig):
    """Config hook for Nomic BERT embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Normalize the Nomic HF config and settle ``max_model_len``.

        The Nomic-specific field names (``n_embd``, ``n_layer``, ...) are
        copied to the names used elsewhere, ``rotary_kwargs`` is populated,
        and ``max_model_len`` is then re-verified: capped to
        ``max_trained_positions`` by default, or taken from the user's
        overrides when context extension is requested.
        """
        config = model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        # A single ``bias`` flag is exposed, so all three must agree.
        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer
        model_config.model_arch_config.hidden_size = config.hidden_size
        model_config.model_arch_config.total_num_hidden_layers = (
            config.num_hidden_layers
        )

        head_dim = config.hidden_size // config.num_attention_heads
        max_trained_positions = getattr(config, "max_trained_positions", 2048)

        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_parameters.
        # See #17785 #18755
        if (
            not model_config.hf_overrides
            and model_config.original_max_model_len is None
        ):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len, max_trained_positions)

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )

            if model_config.max_model_len != max_model_len_before:
                logger.warning(
                    "Nomic context extension is disabled. "
                    "Changing max_model_len from %s to %s. "
                    "To enable context extension, see: "
                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                    max_model_len_before,
                    model_config.max_model_len,
                )
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", model_config.max_model_len
                )
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

            # Update the cached derived_max_model_len to enforce the limit
            model_config.model_arch_config.derived_max_model_len_and_key = (
                float(max_trained_positions),
                "max_position_embeddings",
            )

            # The priority of sentence_bert_config.json is higher
            # than max_position_embeddings, so drop its max_seq_length
            # before re-verifying. The encoder config may be absent, in
            # which case there is nothing to drop; work on a copy so the
            # original mapping is not mutated.
            encoder_config = model_config.encoder_config
            if encoder_config is not None:
                encoder_config = deepcopy(encoder_config)
                encoder_config.pop("max_seq_length", None)
                model_config.encoder_config = encoder_config

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )
524
525


526
527
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    """Config hook for the Qwen2 process reward model."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the pooler's ``step_tag_id`` to 151651 when unset."""
        pooler_cfg = model_config.pooler_config
        if pooler_cfg.step_tag_id is not None:
            return
        pooler_cfg.step_tag_id = 151651


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    """Config hook for the Qwen2 reward model."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the pooler's ``use_activation`` to ``False`` when unset."""
        pooler_cfg = model_config.pooler_config
        if pooler_cfg.use_activation is not None:
            return
        pooler_cfg.use_activation = False
542
543


544
545
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for Qwen3 sequence classification."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Configure the text config for an original Qwen3 reranker.

        Only acts when the HF config sets ``is_original_qwen3_reranker``;
        ``classifier_from_token`` must then hold exactly two tokens.
        """
        hf_cfg = model_config.hf_config

        if getattr(hf_cfg, "is_original_qwen3_reranker", False):
            token_pair = getattr(hf_cfg, "classifier_from_token", None)
            assert token_pair is not None and len(token_pair) == 2, (
                "Try loading the original Qwen3 Reranker?, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/pooling/score/qwen3_reranker_offline.py"
            )

            text_cfg = hf_cfg.get_text_config()
            text_cfg.method = "from_2_way_softmax"
            text_cfg.classifier_from_token = token_pair
564
565


566
567
568
569
class Qwen3VLForSequenceClassificationConfig(Qwen3ForSequenceClassificationConfig):
    """Qwen3-VL variant; inherits the Qwen3 classification hook unchanged."""


570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
    """Config hook for Qwen3.5 models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Reconcile mamba_ssm_cache_dtype with the model's mamba_ssm_dtype.

        If the HF text config defines ``mamba_ssm_dtype``, it replaces an
        ``'auto'`` cache dtype. A different, explicitly chosen cache dtype
        is kept and a warning is logged. Without ``mamba_ssm_dtype`` in the
        HF text config nothing changes.
        """
        kv_cfg = vllm_config.cache_config
        text_cfg = vllm_config.model_config.hf_text_config

        model_dtype = getattr(text_cfg, "mamba_ssm_dtype", None)
        if model_dtype is None:
            return

        chosen_dtype = kv_cfg.mamba_ssm_cache_dtype
        if chosen_dtype == "auto":
            kv_cfg.mamba_ssm_cache_dtype = model_dtype
            return

        if chosen_dtype != model_dtype:
            logger.warning(
                "Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
                "but --mamba-ssm-cache-dtype='%s' was passed. "
                "Using the user-specified value.",
                model_dtype,
                chosen_dtype,
            )


597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook for GTE models whose HF config class is ``GteConfig``."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Switch the activation to ``geglu`` and populate ``rotary_kwargs``."""
        hf_cfg = model_config.hf_config

        assert hf_cfg.__class__.__name__ == "GteConfig"
        assert hf_cfg.hidden_act == "gelu"
        hf_cfg.hidden_act = "geglu"

        per_head = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        # Without an explicit rotary_emb_dim the whole head is rotated
        # (partial_rotary_factor == 1.0).
        rope_params = hf_cfg.rope_parameters
        rope_params["partial_rotary_factor"] = (
            getattr(hf_cfg, "rotary_emb_dim", per_head) / per_head
        )

        hf_cfg.rotary_kwargs = dict(
            head_size=per_head,
            max_position=hf_cfg.max_position_embeddings,
            rope_parameters=hf_cfg.rope_parameters,
        )


chengchengpei's avatar
chengchengpei committed
617
618
619
620
621
622
623
class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
    """Config hook for the Voyage Qwen3 bidirectional embedding model."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Disable causal masking and set ``embedding_size`` to ``num_labels``."""
        hf_cfg = model_config.hf_config
        hf_cfg.is_causal = False
        hf_cfg.embedding_size = hf_cfg.num_labels


624
# Architecture name -> hook class whose verify_and_update_* methods apply
# to that architecture. Several architectures may share one hook class.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "ColBERTJinaRobertaModel": JinaRobertaModelConfig,
    "ColQwen3_5": Qwen3_5ForConditionalGenerationConfig,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
    "Ernie4_5_VLMoeForConditionalGeneration": (
        Ernie4_5_VLMoeForConditionalGenerationConfig
    ),
    "FalconMambaForCausalLM": MambaModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "Gemma4ForCausalLM": Gemma4Config,
    "Gemma4ForConditionalGeneration": Gemma4Config,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "JinaForRanking": JinaForRankingConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
    "LlamaNemotronVLModel": LlamaNemotronVLConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "MambaForCausalLM": MambaModelConfig,
    "NemotronHForCausalLM": NemotronHForCausalLMConfig,
    "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
    "NemotronH_Nano_VL_V2": NemotronHNanoVLV2Config,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
    "Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
}