"requirements/docs.txt" did not exist on "98d6682cd1f27fa48bf915d3fd3e1eb1ee3014c4"
config.py 22.7 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import TYPE_CHECKING

from vllm.logger import init_logger
7
from vllm.utils.math_utils import round_up
8
9

if TYPE_CHECKING:
10
    from vllm.config import ModelConfig, VllmConfig
11
12
13
14
15
16
17

# Module-level logger shared by the config hooks defined in this file.
logger = init_logger(__name__)


class VerifyAndUpdateConfig:
    """Base class for per-architecture config hooks.

    Subclasses override one or both hooks to validate or adjust configuration
    for a particular model architecture. Both default implementations are
    no-ops that return ``None``.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Hook operating on the full engine config; does nothing by default."""
        return None

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Hook operating on the model config only; does nothing by default."""
        return None
23
24


25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    """Config hook for DeepSeek-V3.2 causal LM checkpoints."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Sanity-check the V3.2 marker and map a "bfloat16" kv-cache dtype
        to "auto".
        """
        # Same detection rule as vllm/model_executor/models/deepseek_v2.py:
        # a V3.2 config carries an `index_topk` attribute.
        assert hasattr(vllm_config.model_config.hf_config, "index_topk")

        cache_cfg = vllm_config.cache_config
        if cache_cfg.cache_dtype != "bfloat16":
            return
        cache_cfg.cache_dtype = "auto"
        logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")


class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig):
    """Config hook for Ernie4.5-VL MoE models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Unconditionally turn off ``fast_moe_cold_start``."""
        # Ernie4.5-VL runs its text and vision MoE branches conditionally, and
        # fast_moe_cold_start can then silently pick the wrong execution order.
        compilation_cfg = vllm_config.compilation_config
        compilation_cfg.fast_moe_cold_start = False


48
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    """Config hook for Gemma3 text models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Set ``is_causal`` to the negation of ``use_bidirectional_attention``."""
        text_cfg = model_config.hf_config
        bidirectional = text_cfg.use_bidirectional_attention
        text_cfg.is_causal = not bidirectional


55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    """Config hook for gpt-oss models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Apply gpt-oss defaults.

        * If the structured-outputs reasoning parser is the empty string,
          default it to "openai_gptoss".
        * Raise the max CUDA graph capture size to 1024, unless the user has
          set either ``cudagraph_capture_sizes`` or
          ``max_cudagraph_capture_size``.
        """
        so_cfg = vllm_config.structured_outputs_config
        if so_cfg.reasoning_parser == "":
            so_cfg.reasoning_parser = "openai_gptoss"

        # Going from 512 to 1024 max capture size is a performance win.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 83.
        comp_cfg = vllm_config.compilation_config
        user_configured = (
            comp_cfg.cudagraph_capture_sizes is not None
            or comp_cfg.max_cudagraph_capture_size is not None
        )
        if user_configured:
            return

        target_size = 1024
        comp_cfg.max_cudagraph_capture_size = target_size
        logger.info(
            "Overriding max cuda graph capture size to %d for performance.",
            target_size,
        )


78
79
class GteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook for GTE "NewModel" checkpoints (``NewConfig``)."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Rewrite ``hidden_act`` to "geglu" and build ``rotary_kwargs``.

        Expects a ``NewConfig`` whose ``hidden_act`` is "gelu". Records the
        partial rotary factor in ``rope_parameters`` and stores the rotary
        embedding kwargs on ``config.rotary_kwargs``.
        """
        cfg = model_config.hf_config

        assert cfg.__class__.__name__ == "NewConfig"
        assert cfg.hidden_act == "gelu"
        cfg.hidden_act = "geglu"

        head_size = cfg.hidden_size // cfg.num_attention_heads
        # Rotary dim defaults to the full head size when not specified.
        rot_dim = getattr(cfg, "rotary_emb_dim", head_size)
        cfg.rope_parameters["partial_rotary_factor"] = rot_dim / head_size

        cfg.rotary_kwargs = dict(
            head_size=head_size,
            max_position=cfg.max_position_embeddings,
            rope_parameters=cfg.rope_parameters,
        )


98
99
100
101
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    """Config hook for hybrid attention + recurrent (mamba-style) models."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Early validation and setup for hybrid attention/mamba models.

        Aligning the block size with mamba page sizes is deferred to
        Platform.update_block_size_for_backend(), which runs once the model
        layers exist and the attention backend is known.

        Args:
            vllm_config: vLLM Config
        """
        cache_cfg = vllm_config.cache_config

        # KV-scale calibration is unreliable for hybrid models: the recurrent
        # state is uninitialized during the calibration pass and corrupts the
        # scales. See issue: https://github.com/vllm-project/vllm/issues/37554
        if cache_cfg.calculate_kv_scales:
            logger.warning(
                "Disabling calculate_kv_scales for hybrid model '%s'. "
                "Hybrid models with recurrent layers (GDN, Mamba, SSM) "
                "produce unreliable KV cache scales during the "
                "calibration pass because recurrent state is "
                "uninitialized. Using default scale of 1.0 instead.",
                vllm_config.model_config.model,
            )
            cache_cfg.calculate_kv_scales = False

        # Share the mamba setup (mamba cache mode and mamba block size) with
        # pure mamba models.
        MambaModelConfig.verify_and_update_config(vllm_config)


132
133
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for Jamba sequence-classification models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.use_activation`` to False when unset."""
        pooler_cfg = model_config.pooler_config
        if pooler_cfg.use_activation is not None:
            return
        pooler_cfg.use_activation = False
138
139


140
141
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    """Config hook for (Jina) XLM-RoBERTa models with rotary embeddings."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Build ``rotary_kwargs`` for rotary position embeddings.

        Does nothing unless ``position_embedding_type`` is "rotary"; in that
        case the config must be an ``XLMRobertaFlashConfig``.
        """
        cfg = model_config.hf_config
        if cfg.position_embedding_type != "rotary":
            return

        assert cfg.__class__.__name__ == "XLMRobertaFlashConfig"

        head_size = cfg.hidden_size // cfg.num_attention_heads

        # Jina-embeddings-v3 ships max_position_embeddings=8194, which is not a
        # multiple of triton's num_warps (4 or 8 by default). Under
        # torch.compile that causes out-of-bound RoPE cache indexing on long
        # prompts, so pad up to a multiple of 8 unless running eagerly.
        max_pos = cfg.max_position_embeddings
        if not model_config.enforce_eager:
            max_pos = round_up(max_pos, 8)

        rot_dim = getattr(cfg, "rotary_emb_dim", head_size)
        cfg.rope_parameters["partial_rotary_factor"] = rot_dim / head_size

        cfg.rotary_kwargs = dict(
            head_size=head_size,
            max_position=max_pos,
            rope_parameters=cfg.rope_parameters,
        )


168
169
170
171
172
173
174
175
176
177
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for Jina VL ranking models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Force a single output label; default the pooler logit bias to 2.65."""
        model_config.hf_config.num_labels = 1

        pooler_cfg = model_config.pooler_config
        if pooler_cfg.logit_bias is None:
            pooler_cfg.logit_bias = 2.65


178
179
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
    """Config hook for bidirectional (non-causal) Llama models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Disable causal attention and translate ``hf_config.pooling``.

        Raises:
            ValueError: if ``hf_config.pooling`` is not "avg", "cls" or "last".
        """
        from vllm.config.pooler import SequencePoolingType

        hf_cfg = model_config.hf_config
        hf_cfg.is_causal = False

        supported: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }
        requested = hf_cfg.pooling
        if requested not in supported:
            raise ValueError(f"pool_type {requested!r} not supported")

        model_config.pooler_config.seq_pooling_type = supported[requested]
197
198


199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
    """Config handler for LlamaNemotronVL embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Make the model bidirectional and configure sequence pooling.

        * ``is_causal`` is cleared on the top-level config and, when an
          ``llm_config`` attribute exists, on it too.
        * ``patch_size`` is copied up from ``vision_config`` when present.
        * The pooling mode is read from the top-level ``pooling`` attribute.
          When that is missing/None and an ``llm_config`` exists, it falls
          back to ``llm_config.pooling`` (default "avg").

        Raises:
            ValueError: if the resolved pooling mode is not "avg", "cls" or
                "last" (including when no pooling mode could be resolved).
        """
        from vllm.config.pooler import SequencePoolingType

        hf_cfg = model_config.hf_config
        has_llm_cfg = hasattr(hf_cfg, "llm_config")

        # Bidirectional attention on the language model.
        hf_cfg.is_causal = False
        if has_llm_cfg:
            hf_cfg.llm_config.is_causal = False

        if hasattr(hf_cfg, "vision_config"):
            hf_cfg.patch_size = hf_cfg.vision_config.patch_size

        supported: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        pooling = getattr(hf_cfg, "pooling", None)
        if pooling is None and has_llm_cfg:
            pooling = getattr(hf_cfg.llm_config, "pooling", "avg")

        seq_pooling = supported.get(pooling)
        if seq_pooling is None:
            raise ValueError(f"pool_type {pooling!r} not supported")
        model_config.pooler_config.seq_pooling_type = seq_pooling


235
236
237
238
239
240
class MambaModelConfig(VerifyAndUpdateConfig):
    """Config hook for mamba models (also reused by hybrid models)."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Reconcile the mamba cache mode and mamba block size with the
        prefix-caching setting.

        With prefix caching enabled:
          * a "none" mamba cache mode becomes "all" if the model supports
            mamba prefix caching, otherwise "align";
          * "all" falls back to "align" for models without that support;
          * "align" requires chunked prefill;
          * an unset mamba_block_size defaults to the cache block size.

        With prefix caching disabled, the mamba cache mode is forced to
        "none" and an unset mamba_block_size defaults to max_model_len.

        Args:
            vllm_config: vLLM Config

        Raises:
            ValueError: if mamba cache mode "align" is selected while chunked
                prefill is disabled.
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if cache_config.mamba_cache_mode == "none":
                cache_config.mamba_cache_mode = (
                    "all" if model_config.supports_mamba_prefix_caching else "align"
                )
                logger.warning(
                    "Mamba cache mode is set to '%s' for %s by default "
                    "when prefix caching is enabled",
                    cache_config.mamba_cache_mode,
                    model_config.architecture,
                )
            if (
                cache_config.mamba_cache_mode == "all"
                and not model_config.supports_mamba_prefix_caching
            ):
                cache_config.mamba_cache_mode = "align"
                logger.warning(
                    "Hybrid or mamba-based model detected without support "
                    "for prefix caching with Mamba cache 'all' mode: "
                    "falling back to 'align' mode."
                )
            # Explicit check rather than `assert`: this validates user
            # configuration and must not be stripped under `python -O`.
            if (
                cache_config.mamba_cache_mode == "align"
                and not vllm_config.scheduler_config.enable_chunked_prefill
            ):
                raise ValueError(
                    "Chunked prefill is required for mamba cache mode 'align'."
                )
            logger.info(
                "Warning: Prefix caching in Mamba cache '%s' "
                "mode is currently enabled. "
                "Its support for Mamba layers is experimental. "
                "Please report any issues you may observe.",
                cache_config.mamba_cache_mode,
            )
            # Without prefix caching the mamba block size defaults to
            # max_model_len (see the else branch). With prefix caching, align
            # it to the cache block size, the basic prefix-caching granularity.
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = cache_config.block_size
        else:
            if cache_config.mamba_cache_mode != "none":
                cache_config.mamba_cache_mode = "none"
                logger.warning(
                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
                )
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = model_config.max_model_len


class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    """Config hook for NemotronH causal LM models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Resolve an "auto" ``mamba_ssm_cache_dtype`` for NemotronH.

        When the cache dtype is "auto", use ``mamba_ssm_cache_dtype`` from the
        HF config, or "float16" if the HF config does not define it. Any other
        value is left untouched.
        """
        cache_cfg = vllm_config.cache_config
        if cache_cfg.mamba_ssm_cache_dtype != "auto":
            return

        hf_cfg = vllm_config.model_config.hf_config
        resolved = getattr(hf_cfg, "mamba_ssm_cache_dtype", "float16")
        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            resolved,
        )
        cache_cfg.mamba_ssm_cache_dtype = resolved


class NemotronHNanoVLV2Config(VerifyAndUpdateConfig):
    """Config hook for NemotronH Nano VL V2."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default the video IO backend to "nemotron_vl".

        Only applies when a multimodal config exists; an explicitly
        configured ``video_backend`` is kept.
        """
        mm_cfg = model_config.multimodal_config
        if mm_cfg is None:
            return
        video_io = mm_cfg.media_io_kwargs.setdefault("video", {})
        video_io.setdefault("video_backend", "nemotron_vl")


class NomicBertModelConfig(VerifyAndUpdateConfig):
    """Config hook translating nomic-bert HF configs to standard BERT fields."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Normalize a ``NomicBertConfig`` and settle ``max_model_len``.

        Steps:
          1. Validate the supported configuration subset and copy nomic field
             names (``n_embd``, ``n_layer``, ``n_inner``,
             ``layer_norm_epsilon``, ...) onto the standard attribute names.
          2. Build ``rotary_kwargs`` with ``max_trained_positions``
             (default 2048) as the RoPE max position.
          3. Choose ``max_model_len``:
             * with no ``hf_overrides`` and no ``original_max_model_len``,
               clamp it to ``max_trained_positions`` (context extension off);
             * otherwise re-verify the overridden/current value through
               ``get_and_verify_max_len`` with ``max_trained_positions`` as the
               position-embedding limit.
        """
        config = model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer
        model_config.model_arch_config.hidden_size = config.hidden_size
        model_config.model_arch_config.total_num_hidden_layers = (
            config.num_hidden_layers
        )

        head_dim = config.hidden_size // config.num_attention_heads
        max_trained_positions = getattr(config, "max_trained_positions", 2048)

        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_parameters.
        # See #17785 #18755
        if (
            not model_config.hf_overrides
            and model_config.original_max_model_len is None
        ):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len, max_trained_positions)

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )

            if model_config.max_model_len != max_model_len_before:
                logger.warning(
                    "Nomic context extension is disabled. "
                    "Changing max_model_len from %s to %s. "
                    "To enable context extension, see: "
                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                    max_model_len_before,
                    model_config.max_model_len,
                )
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", model_config.max_model_len
                )
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

            # Update the cached derived_max_model_len to enforce the limit
            model_config.model_arch_config.derived_max_model_len_and_key = (
                float(max_trained_positions),
                "max_position_embeddings",
            )

            # The priority of sentence_bert_config.json is higher than
            # max_position_embeddings, so drop its max_seq_length before
            # re-verifying. encoder_config is presumably None when the
            # checkpoint ships no sentence_bert_config.json (TODO confirm);
            # guard so we never call .pop() on None.
            if model_config.encoder_config is not None:
                encoder_config = deepcopy(model_config.encoder_config)
                encoder_config.pop("max_seq_length", None)
                model_config.encoder_config = encoder_config

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )
429
430


431
432
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    """Config hook for Qwen2 process reward models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.step_tag_id`` to 151651 when unset."""
        pooler_cfg = model_config.pooler_config
        if pooler_cfg.step_tag_id is not None:
            return
        # Default step tag id (presumably a Qwen2 tokenizer token id; confirm).
        pooler_cfg.step_tag_id = 151651


class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    """Config hook for Qwen2 reward models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Default ``pooler_config.use_activation`` to False when unset."""
        pooler_cfg = model_config.pooler_config
        if pooler_cfg.use_activation is not None:
            return
        pooler_cfg.use_activation = False
447
448


449
450
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for Qwen3 sequence-classification / reranker models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Configure original Qwen3 Reranker checkpoints.

        Does nothing unless ``hf_config.is_original_qwen3_reranker`` is truthy.
        When it is, ``classifier_from_token`` must contain exactly two tokens.
        Those tokens are copied onto the text config along with
        ``method = "from_2_way_softmax"``.

        Raises:
            ValueError: if ``classifier_from_token`` is missing or does not
                contain exactly two entries.
        """
        config = model_config.hf_config

        is_original_qwen3_reranker = getattr(
            config, "is_original_qwen3_reranker", False
        )
        if not is_original_qwen3_reranker:
            return

        tokens = getattr(config, "classifier_from_token", None)
        # Explicit check rather than `assert`: this validates the user's model
        # config and must not be stripped under `python -O` (which would
        # silently store None below).
        if tokens is None or len(tokens) != 2:
            raise ValueError(
                "Try loading the original Qwen3 Reranker?, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/pooling/score/qwen3_reranker_offline.py"
            )

        text_config = config.get_text_config()
        text_config.method = "from_2_way_softmax"
        text_config.classifier_from_token = tokens
469
470


471
472
473
474
class Qwen3VLForSequenceClassificationConfig(Qwen3ForSequenceClassificationConfig):
    """Qwen3-VL variant that reuses the Qwen3 sequence-classification hook."""


475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
    """Config hook for Qwen3.5 models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Reconcile ``mamba_ssm_cache_dtype`` with the HF ``mamba_ssm_dtype``.

        Nothing happens when the HF text config has no ``mamba_ssm_dtype``.
        Otherwise:

        * an "auto" cache dtype adopts the HF value;
        * an explicit cache dtype that differs from the HF value is kept,
          and a warning is logged.
        """
        cache_cfg = vllm_config.cache_config
        text_cfg = vllm_config.model_config.hf_text_config
        model_dtype = getattr(text_cfg, "mamba_ssm_dtype", None)
        if model_dtype is None:
            return

        user_dtype = cache_cfg.mamba_ssm_cache_dtype
        if user_dtype == "auto":
            cache_cfg.mamba_ssm_cache_dtype = model_dtype
        elif user_dtype != model_dtype:
            logger.warning(
                "Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
                "but --mamba-ssm-cache-dtype='%s' was passed. "
                "Using the user-specified value.",
                model_dtype,
                user_dtype,
            )


502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook for Snowflake GTE checkpoints (``GteConfig``)."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Rewrite ``hidden_act`` to "geglu" and build ``rotary_kwargs``.

        Expects a ``GteConfig`` whose ``hidden_act`` is "gelu". Records the
        partial rotary factor in ``rope_parameters``.
        """
        hf_cfg = model_config.hf_config

        assert hf_cfg.__class__.__name__ == "GteConfig"
        assert hf_cfg.hidden_act == "gelu"
        hf_cfg.hidden_act = "geglu"

        head_size = hf_cfg.hidden_size // hf_cfg.num_attention_heads
        # Rotary dim defaults to the full head size when unspecified.
        rotary_fraction = getattr(hf_cfg, "rotary_emb_dim", head_size) / head_size
        hf_cfg.rope_parameters["partial_rotary_factor"] = rotary_fraction

        hf_cfg.rotary_kwargs = dict(
            head_size=head_size,
            max_position=hf_cfg.max_position_embeddings,
            rope_parameters=hf_cfg.rope_parameters,
        )


chengchengpei's avatar
chengchengpei committed
522
523
524
525
526
527
528
class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
    """Config hook for Voyage bidirectional Qwen3 embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Disable causal attention; set ``embedding_size`` from ``num_labels``."""
        hf_cfg = model_config.hf_config
        hf_cfg.is_causal = False
        hf_cfg.embedding_size = hf_cfg.num_labels


529
# Registry of per-architecture config hooks: model architecture name -> hook
# class. Keys presumably match entries of the HF config's `architectures`
# list (TODO confirm against the lookup site). Several architectures may share
# one hook class.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "ColBERTJinaRobertaModel": JinaRobertaModelConfig,
    "ColQwen3_5": Qwen3_5ForConditionalGenerationConfig,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
    "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig,  # noqa: E501
    "FalconMambaForCausalLM": MambaModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
    "LlamaNemotronVLModel": LlamaNemotronVLConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "MambaForCausalLM": MambaModelConfig,
    "NemotronHForCausalLM": NemotronHForCausalLMConfig,
    "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
    "NemotronH_Nano_VL_V2": NemotronHNanoVLV2Config,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
    "Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
}