# test_config.py (40.3 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import logging
import os
from dataclasses import MISSING, Field, asdict, dataclass, field
from unittest.mock import patch

import pytest
from pydantic import ValidationError

from vllm.compilation.backends import VllmBackend
from vllm.config import (
    CompilationConfig,
    ModelConfig,
    ParallelConfig,
    PoolerConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
    update_config,
)
from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.load import LoadConfig
from vllm.config.utils import get_field
from vllm.config.vllm import (
    OPTIMIZATION_LEVEL_TO_CONFIG,
    OptimizationLevel,
)
from vllm.platforms import current_platform
31

32

33
34
35
36
37
38
39
40
def test_compile_config_repr_succeeds():
    """``repr(VllmConfig)`` must still work after ``VllmBackend`` mutates it.

    ``VllmBackend(config).configure_post_pass()`` mutates the config object in
    place; the resulting repr must succeed and mention ``inductor_passes``.
    """
    # setup: VllmBackend mutates the config object
    config = VllmConfig()
    backend = VllmBackend(config)
    backend.configure_post_pass()

    # test that repr(config) succeeds
    val = repr(config)

    assert "VllmConfig" in val
    assert "inductor_passes" in val
43
44


45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def test_async_scheduling_with_pipeline_parallelism_is_allowed():
    """Async scheduling combined with pipeline parallelism must be accepted.

    Building the ``VllmConfig`` must not reject the combination, and the
    ``async_scheduling`` flag must survive construction unchanged.
    """
    scheduler = SchedulerConfig(
        max_model_len=8192,
        is_encoder_decoder=False,
        async_scheduling=True,
    )
    parallel = ParallelConfig(
        pipeline_parallel_size=2,
        distributed_executor_backend="mp",
        nnodes=2,
    )
    cfg = VllmConfig(scheduler_config=scheduler, parallel_config=parallel)
    assert cfg.scheduler_config.async_scheduling is True


61
62
63
64
65
@dataclass
class _TestConfigFields:
    """Small dataclass fixture for the ``get_field`` / ``update_config`` tests.

    The three fields cover the cases ``test_get_field`` checks: a field with
    no default at all, one with a ``default_factory``, and one with a plain
    default value.
    """

    # Required field: neither ``default`` nor ``default_factory`` is set.
    a: int
    # Mutable default supplied through ``default_factory=dict``.
    b: dict = field(default_factory=dict)
    # Plain default value.
    c: str = "default"
66
67


68
def test_get_field():
    """``get_field`` returns the dataclass ``Field`` for fields with defaults.

    A field with neither ``default`` nor ``default_factory`` is rejected with
    ``ValueError``.
    """
    with pytest.raises(ValueError):
        get_field(_TestConfigFields, "a")

    # Field with a default_factory: no plain default is recorded.
    b = get_field(_TestConfigFields, "b")
    assert isinstance(b, Field)
    assert b.default is MISSING
    assert b.default_factory is dict

    # Field with a plain default: no default_factory is recorded.
    c = get_field(_TestConfigFields, "c")
    assert isinstance(c, Field)
    assert c.default == "default"
    assert c.default_factory is MISSING


83
84
@dataclass
class _TestNestedConfig:
    """Dataclass fixture whose single field is a nested ``_TestConfigFields``."""

    # default_factory gives each outer instance its own inner instance (a=0).
    a: _TestConfigFields = field(default_factory=lambda: _TestConfigFields(a=0))
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109


def test_update_config():
    """Exercise ``update_config`` on flat and nested dataclass configs."""
    # Flat config: overriding an existing field yields the new value.
    flat = _TestConfigFields(a=0)
    updated_flat = update_config(flat, {"a": 42})
    assert updated_flat.a == 42

    # Unknown keys are rejected.
    with pytest.raises(AssertionError):
        update_config(flat, {"nonexistent": 1})

    # Nested config: a whole replacement dataclass is taken as-is.
    replacement = _TestConfigFields(a=1, c="new_value")
    assert update_config(_TestNestedConfig(), {"a": replacement}).a == replacement

    # Nested config: a dict patches individual inner fields.
    nested = _TestNestedConfig()
    patched = update_config(nested, {"a": {"c": "new_value"}})
    assert patched.a.c == "new_value"

    # Nested config: any other value type for a dataclass field is rejected.
    with pytest.raises(AssertionError):
        update_config(nested, {"a": "new_value"})


110
111
112
113
114
115
116
117
118
119
120
121
122
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("distilbert/distilgpt2", "generate", "none"),
        ("intfloat/multilingual-e5-small", "pooling", "none"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none"),
        ("openai/whisper-small", "generate", "none"),
    ],
)
def test_auto_runner(model_id, expected_runner_type, expected_convert_type):
    """With ``runner="auto"`` the runner/convert types are resolved per model."""
    config = ModelConfig(model_id, runner="auto")

    assert config.runner_type == expected_runner_type
    assert config.convert_type == expected_convert_type
126
127
128


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("distilbert/distilgpt2", "pooling", "embed"),
        ("intfloat/multilingual-e5-small", "pooling", "none"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none"),
        ("openai/whisper-small", "pooling", "embed"),
    ],
)
def test_pooling_runner(model_id, expected_runner_type, expected_convert_type):
    """Forcing ``runner="pooling"`` always gives the pooling runner.

    Per the table above, checkpoints that ``test_auto_runner`` resolves to the
    generate runner are expected to get ``convert_type == "embed"`` here.
    """
    config = ModelConfig(model_id, runner="pooling")

    assert config.runner_type == expected_runner_type
    assert config.convert_type == expected_convert_type
144
145


146
147
148
149
150
151
152
153
154
155
156
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("Qwen/Qwen2.5-1.5B-Instruct", "draft", "none"),
    ],
)
def test_draft_runner(model_id, expected_runner_type, expected_convert_type):
    """``runner="draft"`` yields the draft runner with no model conversion."""
    draft_config = ModelConfig(model_id, runner="draft")
    observed = (draft_config.runner_type, draft_config.convert_type)
    assert observed == (expected_runner_type, expected_convert_type)
157
158


159
160
161
162
163
164
165
166
167
168
# (model_id, expected max_model_len when the sliding window is disabled)
MODEL_IDS_EXPECTED = [
    ("Qwen/Qwen1.5-7B", 32768),
    ("mistralai/Mistral-7B-v0.1", 4096),
    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """``disable_sliding_window=True`` yields the expected ``max_model_len``."""
    model_id, expected = model_id_expected
    model_config = ModelConfig(model_id, disable_sliding_window=True)

    assert model_config.max_model_len == expected

172

173
174
175
@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_pooling_config():
    """Without an explicit ``pooler_config``, one is derived for the model.

    For this sentence-transformers checkpoint the derived config is expected
    to use MEAN sequence pooling, ALL token pooling and an enabled activation.
    """
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(model_id)

    assert model_config.pooler_config is not None
    assert model_config.pooler_config.use_activation
    assert model_config.pooler_config.seq_pooling_type == "MEAN"
    assert model_config.pooler_config.tok_pooling_type == "ALL"
184
185


186
187
188
@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_pooling_config_from_args():
    """An explicit ``PoolerConfig`` is kept field-for-field by ``ModelConfig``."""
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
    model_config = ModelConfig(model_id, pooler_config=pooler_config)

    assert asdict(model_config.pooler_config) == asdict(pooler_config)
195
196


197
198
199
200
201
@pytest.mark.parametrize(
    ("model_id", "default_pooling_type", "pooling_type"),
    [
        ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"),  # LLM
        ("intfloat/e5-small", "CLS", "MEAN"),  # BertModel
    ],
)
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
    """Compare the architecture's default sequence pooling with the resolved one.

    The two can differ (e.g. e5-small: default CLS, resolved MEAN).
    """
    model_config = ModelConfig(model_id)
    assert model_config._model_info.default_seq_pooling_type == default_pooling_type
    assert model_config.pooler_config.seq_pooling_type == pooling_type


@pytest.mark.parametrize(
    ("model_id", "default_pooling_type", "pooling_type"),
    [
        ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"),  # reward
        ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"),  # step reward
    ],
)
def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
    """Compare the architecture's default token pooling with the resolved one."""
    model_config = ModelConfig(model_id)
    assert model_config._model_info.default_tok_pooling_type == default_pooling_type
    assert model_config.pooler_config.tok_pooling_type == pooling_type
221
222


223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
@pytest.mark.parametrize(
    ("model_id", "expected_is_moe_model"),
    [
        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", False),
        ("RedHatAI/Llama-3.2-1B-FP8", False),
        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", False),
        ("RedHatAI/gpt-oss-20b", True),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
        ("RedHatAI/Llama-4-Scout-17B-16E-Instruct", True),
        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", True),
    ],
)
def test_moe_model_detection(model_id, expected_is_moe_model):
    """``ModelConfig.is_moe`` matches the expected MoE classification."""
    model_config = ModelConfig(model_id)

    # Check the detected value itself, not just that the attribute exists.
    assert model_config.is_moe == expected_is_moe_model
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256


@pytest.mark.parametrize(
    ("model_id", "quantized"),
    [
        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", True),
        ("RedHatAI/Llama-3.2-1B-FP8", True),
        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", True),
        ("RedHatAI/gpt-oss-20b", True),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", False),
    ],
)
def test_is_quantized(model_id, quantized):
    """``ModelConfig.is_quantized`` matches the expected value per checkpoint."""
    model_config = ModelConfig(model_id)
    # Check the detected value itself, not just that the attribute exists.
    assert model_config.is_quantized == quantized
258
259


260
261
262
@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_bert_tokenization_sentence_transformer_config():
    """``_get_encoder_config()`` exposes BGE's encoder tokenization settings.

    Expected: ``max_seq_length`` of 512 and ``do_lower_case`` enabled.
    """
    model_id = "BAAI/bge-base-en-v1.5"
    bge_model_config = ModelConfig(model_id)

    bert_bge_model_config = bge_model_config._get_encoder_config()

    assert bert_bge_model_config["max_seq_length"] == 512
    assert bert_bge_model_config["do_lower_case"]


273
def test_rope_customization():
    """Check default and overridden ``rope_parameters`` and ``max_model_len``.

    Each model is loaded once as-is and once with ``hf_overrides`` replacing
    ``rope_parameters``. The test asserts both the resulting parameters and
    the context length ``ModelConfig`` derives from them.
    """
    TEST_ROPE_PARAMETERS = {
        "rope_theta": 16_000_000.0,
        "rope_type": "dynamic",
        "factor": 2.0,
    }
    LLAMA_ROPE_PARAMETERS = {"rope_theta": 500000.0, "rope_type": "default"}
    LONGCHAT_ROPE_PARAMETERS = {"rope_type": "linear", "factor": 8.0}

    # Llama 3 as shipped: default rope, 8192-token context.
    llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
    assert (
        getattr(llama_model_config.hf_config, "rope_parameters", None)
        == LLAMA_ROPE_PARAMETERS
    )
    assert llama_model_config.max_model_len == 8192

    # Llama 3 with the dynamic (factor 2.0) override: expected 16384.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
    )
    assert (
        getattr(llama_model_config.hf_config, "rope_parameters", None)
        == TEST_ROPE_PARAMETERS
    )
    assert llama_model_config.max_model_len == 16384

    # LongChat as shipped: its rope_parameters contain the linear scaling
    # entries (other keys may also be present).
    longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
    assert all(
        longchat_model_config.hf_config.rope_parameters.get(key) == value
        for key, value in LONGCHAT_ROPE_PARAMETERS.items()
    )
    assert longchat_model_config.max_model_len == 16384

    # LongChat with the override: rope_parameters are replaced wholesale and
    # max_model_len is expected to drop to 4096.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        hf_overrides={
            "rope_parameters": TEST_ROPE_PARAMETERS,
        },
    )
    assert (
        getattr(longchat_model_config.hf_config, "rope_parameters", None)
        == TEST_ROPE_PARAMETERS
    )
    assert longchat_model_config.max_model_len == 4096
318
319


320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
def test_nested_hf_overrides():
    """Nested ``hf_overrides`` dicts are applied to the matching sub-configs."""
    model_id = "Qwen/Qwen2-VL-2B-Instruct"

    # A single key overridden inside the text sub-config.
    single = ModelConfig(
        model_id,
        hf_overrides={"text_config": {"hidden_size": 1024}},
    )
    assert single.hf_config.text_config.hidden_size == 1024

    # Several keys spread across more than one nested sub-config.
    multi = ModelConfig(
        model_id,
        hf_overrides={
            "text_config": {"hidden_size": 2048, "num_attention_heads": 16},
            "vision_config": {"hidden_size": 512},
        },
    )
    text_cfg = multi.hf_config.text_config
    assert text_cfg.hidden_size == 2048
    assert text_cfg.num_attention_heads == 16
    assert multi.hf_config.vision_config.hidden_size == 512


351
352
353
354
355
356
357
358
359
360
361
@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm."
)
@pytest.mark.parametrize(
    ("model_id", "is_encoder_decoder"),
    [
        ("facebook/opt-125m", False),
        ("openai/whisper-tiny", True),
        ("meta-llama/Llama-3.2-1B-Instruct", False),
    ],
)
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """``ModelConfig.is_encoder_decoder`` matches the expected value per model."""
    config = ModelConfig(model_id)

    assert config.is_encoder_decoder == is_encoder_decoder


368
369
370
371
372
373
374
@pytest.mark.parametrize(
    ("model_id", "uses_mrope"),
    [
        ("facebook/opt-125m", False),
        ("Qwen/Qwen2-VL-2B-Instruct", True),
    ],
)
def test_uses_mrope(model_id, uses_mrope):
    """``ModelConfig.uses_mrope`` matches the expected value per model."""
    config = ModelConfig(model_id)

    assert config.uses_mrope == uses_mrope
379
380
381
382
383


def test_generation_config_loading():
    """Check how ``generation_config`` and ``override_generation_config`` combine.

    ``get_diff_sampling_param()`` should return:

    - nothing for ``generation_config="vllm"``;
    - the model's defaults for ``"auto"``;
    - those defaults updated with the user overrides for ``"auto"`` plus
      overrides;
    - only the overrides for ``"vllm"`` plus overrides.
    """
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"

    # When set generation_config to "vllm", the default generation config
    # will not be loaded.
    model_config = ModelConfig(model_id, generation_config="vllm")
    assert model_config.get_diff_sampling_param() == {}

    # When set generation_config to "auto", the default generation config
    # should be loaded.
    model_config = ModelConfig(model_id, generation_config="auto")

    correct_generation_config = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }

    assert model_config.get_diff_sampling_param() == correct_generation_config

    # The generation config could be overridden by the user.
    override_generation_config = {"temperature": 0.5, "top_k": 5}

    model_config = ModelConfig(
        model_id,
        generation_config="auto",
        override_generation_config=override_generation_config,
    )

    override_result = correct_generation_config.copy()
    override_result.update(override_generation_config)

    assert model_config.get_diff_sampling_param() == override_result

    # When generation_config is set to "vllm" and override_generation_config
    # is set, the override_generation_config should be used directly.
    model_config = ModelConfig(
        model_id,
        generation_config="vllm",
        override_generation_config=override_generation_config,
    )

    assert model_config.get_diff_sampling_param() == override_generation_config
425
426


427
428
429
430
431
432
433
@pytest.mark.parametrize(
    "pt_load_map_location",
    [
        "cuda",
        {"": "cuda"},
    ],
)
def test_load_config_pt_load_map_location(pt_load_map_location):
    """``pt_load_map_location`` is preserved as given by ``VllmConfig``.

    Both a plain device string and a mapping dict are checked.
    """
    load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
    config = VllmConfig(load_config=load_config)

    assert config.load_config.pt_load_map_location == pt_load_map_location
439
440
441


@pytest.mark.parametrize(
    ("model_id", "max_model_len", "expected_max_len", "should_raise"),
    [
        ("BAAI/bge-reranker-base", None, 512, False),
        ("BAAI/bge-reranker-base", 256, 256, False),
        ("BAAI/bge-reranker-base", 513, 512, True),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", None, 131072, False),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 131073, 131072, True),
    ],
)
def test_get_and_verify_max_len(
    model_id, max_model_len, expected_max_len, should_raise
):
    """Test get_and_verify_max_len with different configurations.

    ``None`` falls back to the model's limit, a smaller value is kept as
    given, and a value above the limit raises ``ValueError``.
    ``expected_max_len`` is only checked on the non-raising path.
    """
    model_config = ModelConfig(model_id)

    if should_raise:
        with pytest.raises(ValueError):
            model_config.get_and_verify_max_len(max_model_len)
    else:
        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
        assert actual_max_len == expected_max_len
463
464


465
466
class MockConfig:
    """Simple mock object for testing maybe_pull_model_tokenizer_for_runai.

    The S3 tests pass it as ``self`` to
    ``ModelConfig.maybe_pull_model_tokenizer_for_runai``. It carries only
    ``model``, ``tokenizer`` and ``model_weights``.
    """

    def __init__(self, model: str, tokenizer: str):
        self.model = model
        self.tokenizer = tokenizer
        # Starts unset; the tests in this file never inspect it afterwards.
        self.model_weights = None
472
473


474
475
476
477
478
479
480
481
@pytest.mark.parametrize(
    "s3_url",
    [
        "s3://example-bucket-1/model/",
        "s3://example-bucket-2/model/",
    ],
)
@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files")
def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url):
    """Test that S3 URLs create deterministic local directories for model and
    tokenizer."""
    # Mock pull_files to avoid actually downloading files during tests
    mock_pull_files.return_value = None

    def _assert_local_dirs(cfg):
        # Model then tokenizer: each path must exist and be a directory.
        for label, path in (("Model", cfg.model), ("Tokenizer", cfg.tokenizer)):
            assert os.path.exists(path), f"{label} directory does not exist: {path}"
            assert os.path.isdir(path), f"{label} path is not a directory: {path}"

    # First pull: the S3 URLs must be rewritten to local directories.
    config1 = MockConfig(model=s3_url, tokenizer=s3_url)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url, s3_url)
    _assert_local_dirs(config1)

    # Verify that the paths are different from the original S3 URL
    assert config1.model != s3_url, "Model path should be converted to local directory"
    assert config1.tokenizer != s3_url, (
        "Tokenizer path should be converted to local directory"
    )

    # Store the original paths
    created_model_dir = config1.model
    created_tokenizer_dir = config1.tokenizer

    # Second pull of the same URL with a fresh mock.
    config2 = MockConfig(model=s3_url, tokenizer=s3_url)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url, s3_url)
    _assert_local_dirs(config2)

    # Verify that the paths are deterministic (same as before)
    assert config2.model == created_model_dir, (
        f"Model paths are not deterministic. "
        f"Original: {created_model_dir}, New: {config2.model}"
    )
    assert config2.tokenizer == created_tokenizer_dir, (
        f"Tokenizer paths are not deterministic. "
        f"Original: {created_tokenizer_dir}, New: {config2.tokenizer}"
    )
543
544


545
@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files")
def test_s3_url_different_models_create_different_directories(mock_pull_files):
    """Test that different S3 URLs create different local directories."""
    # Mock pull_files to avoid actually downloading files during tests
    mock_pull_files.return_value = None

    s3_url1 = "s3://example-bucket-1/model/"
    s3_url2 = "s3://example-bucket-2/model/"

    # Create mocks with different S3 URLs and run the method
    config1 = MockConfig(model=s3_url1, tokenizer=s3_url1)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url1, s3_url1)

    config2 = MockConfig(model=s3_url2, tokenizer=s3_url2)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url2, s3_url2)

    # Verify that different URLs produce different directories
    assert config1.model != config2.model, (
        f"Different S3 URLs should create different model directories. "
        f"URL1 model: {config1.model}, URL2 model: {config2.model}"
    )
    assert config1.tokenizer != config2.tokenizer, (
        f"Different S3 URLs should create different tokenizer directories. "
        f"URL1 tokenizer: {config1.tokenizer}, "
        f"URL2 tokenizer: {config2.tokenizer}"
    )

    # Verify that all four directories exist (isdir is False for missing paths).
    for path in (config1.model, config1.tokenizer, config2.model, config2.tokenizer):
        assert os.path.isdir(path)
577
578


579
580
581
582
583
584
585
586
@pytest.mark.parametrize(
    ("model_id", "expected_attn_type", "expected_result", "reason"),
    [
        # pooling models
        (
            "jason9693/Qwen2.5-1.5B-apeach",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen3-Embedding-0.6B",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen2.5-Math-PRM-7B",
            "decoder",
            False,
            "Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.",  # noqa: E501
        ),
        (
            "internlm/internlm2-1_8b-reward",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "BAAI/bge-base-en",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "boltuix/NeuroBERT-NER",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "papluca/xlm-roberta-base-language-detection",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "intfloat/e5-small",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        # multimodal models
        (
            "openai/clip-vit-base-patch32",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "google/siglip-base-patch16-224",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        # generate models
        (
            "Qwen/Qwen3-0.6B",
            "decoder",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "hybrid",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "ibm-granite/granite-4.0-h-small",
            "hybrid",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "state-spaces/mamba-130m-hf",
            "attention_free",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        # encoder_decoder models
        (
            "openai/whisper-small",
            "encoder_decoder",
            False,
            "Encoder decoder models do not support chunked prefill.",  # noqa: E501
        ),
    ],
)
def test_is_chunked_prefill_supported(
    model_id: str,
    expected_attn_type: str,
    expected_result: bool,
    reason: str,
    caplog_vllm,
):
    """Check ``attn_type``, chunked-prefill support, and the logged reason.

    The property is read while the ``vllm`` logger captures at DEBUG level,
    and the expected explanation must appear in the captured text.
    """
    model_config = ModelConfig(model_id, trust_remote_code=True)
    assert model_config.attn_type == expected_attn_type
    with caplog_vllm.at_level(level=logging.DEBUG, logger="vllm"):
        assert model_config.is_chunked_prefill_supported == expected_result
    assert reason in caplog_vllm.text


@pytest.mark.parametrize(
    ("model_id", "expected_attn_type", "expected_result", "reason"),
    [
        # pooling models
        (
            "jason9693/Qwen2.5-1.5B-apeach",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
        ),
        (
            "Qwen/Qwen3-Embedding-0.6B",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
        ),
        (
            "Qwen/Qwen2.5-Math-PRM-7B",
            "decoder",
            False,
            "Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.",  # noqa: E501
        ),
        (
            "internlm/internlm2-1_8b-reward",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
        ),
        (
            "BAAI/bge-base-en",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
        ),
        (
            "boltuix/NeuroBERT-NER",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
        ),
        (
            "papluca/xlm-roberta-base-language-detection",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
        ),
        (
            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
        ),
        (
            "intfloat/e5-small",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
        ),
        # multimodal models
        (
            "openai/clip-vit-base-patch32",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
        ),
        (
            "google/siglip-base-patch16-224",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
        ),
        # generate models
        (
            "Qwen/Qwen3-0.6B",
            "decoder",
            True,
            "Generative models support prefix caching.",  # noqa: E501
        ),
        (
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "hybrid",
            False,
            "Hybrid models do not support prefix caching since the feature is still experimental.",  # noqa: E501
        ),
        (
            "ibm-granite/granite-4.0-h-small",
            "hybrid",
            False,
            "Hybrid models do not support prefix caching since the feature is still experimental.",  # noqa: E501
        ),
        (
            "state-spaces/mamba-130m-hf",
            "attention_free",
            False,
            "Attention free models do not support prefix caching since the feature is still experimental.",  # noqa: E501
        ),
        # encoder_decoder models
        (
            "openai/whisper-small",
            "encoder_decoder",
            False,
            "Encoder decoder models do not support prefix caching.",  # noqa: E501
        ),
    ],
)
def test_is_prefix_caching_supported(
    model_id: str,
    expected_attn_type: str,
    expected_result: bool,
    reason: str,
    caplog_vllm,
):
    """Check ``attn_type``, prefix-caching support, and the logged reason.

    The property is read while the ``vllm`` logger captures at DEBUG level,
    and the expected explanation must appear in the captured text.
    """
    model_config = ModelConfig(model_id, trust_remote_code=True)
    assert model_config.attn_type == expected_attn_type
    with caplog_vllm.at_level(level=logging.DEBUG, logger="vllm"):
        assert model_config.is_prefix_caching_supported == expected_result
    assert reason in caplog_vllm.text


817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
@pytest.mark.parametrize(
    ("backend", "custom_ops", "expected"),
    [
        ("eager", [], True),
        ("eager", ["+fused_layernorm"], True),
        ("eager", ["all", "-fused_layernorm"], False),
        ("inductor", [], False),
        ("inductor", ["none", "+fused_layernorm"], True),
        ("inductor", ["none", "-fused_layernorm"], False),
    ],
)
def test_is_custom_op_enabled(backend: str, custom_ops: list[str], expected: bool):
    """``is_custom_op_enabled`` follows the backend default and +/- overrides.

    Per the table: with no ``custom_ops``, the op is on for eager and off for
    inductor; explicit ``+op`` / ``-op`` entries flip that.
    """
    compilation_config = CompilationConfig(backend=backend, custom_ops=custom_ops)
    vllm_config = VllmConfig(compilation_config=compilation_config)
    enabled = vllm_config.compilation_config.is_custom_op_enabled("fused_layernorm")
    assert enabled is expected


def test_vllm_config_defaults_are_none():
    """Check that optimization-level-controlled fields start out as None.

    For every OptimizationLevel, each key that OPTIMIZATION_LEVEL_TO_CONFIG
    lists for that level (both nested ``pass_config`` keys and top-level
    ``compilation_config`` keys) must be None on a freshly built
    CompilationConfig, i.e. before any level defaults are applied.
    """
    for level in OptimizationLevel:
        # object.__new__ skips VllmConfig.__init__ entirely, so nothing the
        # constructor would do (presumably including applying level defaults)
        # runs; only the attributes assigned below exist on the instance.
        bare_config = object.__new__(VllmConfig)
        bare_config.compilation_config = CompilationConfig()
        bare_config.optimization_level = level
        bare_config.model_config = None

        level_defaults = OPTIMIZATION_LEVEL_TO_CONFIG[level]["compilation_config"]
        comp_cfg = bare_config.compilation_config

        # Nested PassConfig fields first.
        for pass_key in level_defaults["pass_config"]:
            assert getattr(comp_cfg.pass_config, pass_key) is None

        # Then every top-level CompilationConfig field named by the defaults.
        top_level_keys = (key for key in level_defaults if key != "pass_config")
        for key in top_level_keys:
            assert getattr(comp_cfg, key) is None


@pytest.mark.parametrize(
    ("model_id", "compilation_config", "optimization_level"),
    [
        (
            None,
            CompilationConfig(backend="eager", custom_ops=["+quant_fp8"]),
            OptimizationLevel.O0,
        ),
        (None, CompilationConfig(), OptimizationLevel.O0),
        (None, CompilationConfig(), OptimizationLevel.O1),
        (None, CompilationConfig(), OptimizationLevel.O2),
        (None, CompilationConfig(), OptimizationLevel.O3),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(backend="inductor", custom_ops=["+quant_fp8"]),
            OptimizationLevel.O2,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O0,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O1,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O2,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O3,
        ),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O0),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O1),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O2),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O3),
    ],
)
def test_vllm_config_defaults(model_id, compilation_config, optimization_level):
    """Test that optimization-level defaults are correctly applied.

    Builds a VllmConfig (with a ModelConfig when ``model_id`` is given) at
    ``optimization_level`` and checks every entry of
    ``OPTIMIZATION_LEVEL_TO_CONFIG[optimization_level]["compilation_config"]``
    against the resulting compilation config. Callable defaults are evaluated
    against the constructed VllmConfig before comparing.
    """
    # Omit model_config entirely (rather than passing None) when there is no
    # model, so the no-model cases rely on VllmConfig's own field default.
    config_kwargs = {}
    if model_id is not None:
        config_kwargs["model_config"] = ModelConfig(model_id)
    vllm_config = VllmConfig(
        compilation_config=compilation_config,
        optimization_level=optimization_level,
        **config_kwargs,
    )

    # Use the global optimization level defaults
    default_config = OPTIMIZATION_LEVEL_TO_CONFIG[optimization_level]

    # Verify pass_config defaults (nested under compilation_config)
    pass_config_dict = default_config["compilation_config"]["pass_config"]
    for pass_k, pass_v in pass_config_dict.items():
        actual = getattr(vllm_config.compilation_config.pass_config, pass_k)
        expected = pass_v(vllm_config) if callable(pass_v) else pass_v
        assert actual == expected, (
            f"pass_config.{pass_k}: expected {expected}, got {actual}"
        )

    # Verify other compilation_config defaults
    compilation_config_dict = default_config["compilation_config"]
    for k, v in compilation_config_dict.items():
        if k != "pass_config":
            actual = getattr(vllm_config.compilation_config, k)
            expected = v(vllm_config) if callable(v) else v
            assert actual == expected, (
                f"compilation_config.{k}: expected {expected}, got {actual}"
            )


def test_vllm_config_callable_defaults():
    """Test that callable defaults work in the config system.

    Verifies that lambdas in default configs can inspect VllmConfig properties
    (e.g., ``model_config.is_quantized``, ``model_config.is_moe``) to
    conditionally set optimization flags, including when no model is present.
    """
    config_no_model = VllmConfig(optimization_level=OptimizationLevel.O2)

    # Callable that checks if model exists
    has_model = lambda cfg: cfg.model_config is not None
    assert has_model(config_no_model) is False

    # Test with quantized model
    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
    config_quantized = VllmConfig(
        model_config=quantized_model, optimization_level=OptimizationLevel.O2
    )
    # Guard on model_config first so the callable is safe without a model.
    enable_if_quantized = lambda cfg: (
        cfg.model_config is not None and cfg.model_config.is_quantized
    )
    assert enable_if_quantized(config_quantized) is True
    assert enable_if_quantized(config_no_model) is False

    # Test with MoE model
    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
    config_moe = VllmConfig(
        model_config=moe_model, optimization_level=OptimizationLevel.O2
    )
    # "Sequential" here means a model is present and it is not MoE.
    enable_if_sequential = lambda cfg: (
        cfg.model_config is not None and not cfg.model_config.is_moe
    )
    assert enable_if_sequential(config_moe) is False
    assert enable_if_sequential(config_quantized) is True


def test_vllm_config_explicit_overrides():
    """Test that explicit property overrides work correctly with callable defaults.

    When users explicitly set configuration properties, those values
    take precedence over callable defaults, across different models and
    optimization levels. Fields that are not overridden must still receive
    the optimization-level defaults.
    """
    from vllm.config.compilation import PassConfig

    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
    regular_model = ModelConfig("Qwen/Qwen1.5-7B")

    # Explicit compilation mode override on O0 (where default is NONE)
    compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    config = VllmConfig(
        optimization_level=OptimizationLevel.O0,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
    # cudagraph_mode was not overridden, so it keeps the O0 value.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE

    # Explicit pass config flags to override defaults
    pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True)
    compilation_config = CompilationConfig(pass_config=pass_config)
    config = VllmConfig(
        optimization_level=OptimizationLevel.O0,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.pass_config.eliminate_noops is True
    assert config.compilation_config.pass_config.fuse_attn_quant is True

    # Explicit cudagraph mode and fuse_gemm_comms overrides on a quantized
    # model at O2
    pass_config = PassConfig(fuse_gemm_comms=True)
    compilation_config = CompilationConfig(
        cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
    )
    config = VllmConfig(
        model_config=quantized_model,
        optimization_level=OptimizationLevel.O2,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert config.compilation_config.pass_config.fuse_gemm_comms is True
    # Mode should still use default for O2
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE

    # Different optimization levels with same model
    config_o0 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O0
    )
    config_o2 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O2
    )
    assert config_o0.compilation_config.mode == CompilationMode.NONE
    assert config_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_o0.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert (
        config_o2.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    )

    # Same optimization level across different model types
    config_moe_o2 = VllmConfig(
        model_config=moe_model, optimization_level=OptimizationLevel.O2
    )
    config_regular_o2 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O2
    )
    config_quantized_o2 = VllmConfig(
        model_config=quantized_model, optimization_level=OptimizationLevel.O2
    )
    # All should have same base compilation settings at O2
    assert config_moe_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_regular_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_quantized_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert (
        config_moe_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )
    assert (
        config_regular_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )

    # Override one field but not others
    pass_config = PassConfig(eliminate_noops=False)
    compilation_config = CompilationConfig(pass_config=pass_config)
    config = VllmConfig(
        model_config=regular_model,
        optimization_level=OptimizationLevel.O2,
        compilation_config=compilation_config,
    )
    # Explicit override should be respected
    assert config.compilation_config.pass_config.eliminate_noops is False
    # Other fields should still use defaults
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082


def test_scheduler_config_init():
    """Check SchedulerConfig's handling of its required InitVar arguments."""
    # Constructing with no arguments must fail pydantic validation: the
    # positional InitVars are required, because giving an InitVar a default
    # would turn it into a regular attribute.
    with pytest.raises(ValidationError):
        SchedulerConfig()

    # An InitVar is consumed during initialisation and is never stored on
    # the resulting instance, so reading it back must fail.
    with pytest.raises(AttributeError):
        default_cfg = SchedulerConfig.default_factory()
        print(default_cfg.max_model_len)
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125


@pytest.mark.parametrize(
    (
        "model_id",
        "data_parallel_size",
        "external_lb",
        "expected_needs_coordinator",
    ),
    [
        # Non-MoE model with DP=1 should not need coordinator
        ("facebook/opt-125m", 1, False, False),
        # Non-MoE model with DP>1 internal LB should need coordinator
        ("facebook/opt-125m", 2, False, True),
        # Non-MoE model with DP>1 external LB should not need coordinator
        ("facebook/opt-125m", 2, True, False),
        # MoE model with DP=1 should not need coordinator
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 1, False, False),
        # MoE model with DP>1 internal LB should need both coordinator
        # and wave coordination
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, False, True),
        # MoE model with DP>1 external LB needs coordinator for wave coordination
        # (wave coordination runs in coordinator process)
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, True, True),
    ],
)
def test_needs_dp_coordination(
    model_id,
    data_parallel_size,
    external_lb,
    expected_needs_coordinator,
):
    """Test that ``VllmConfig.needs_dp_coordinator`` is computed correctly.

    Covers a dense model and an MoE model, each with DP=1 and DP=2, using
    internal and external data-parallel load balancing. Only the coordinator
    flag is asserted here; wave coordination is not checked directly.
    """
    # ParallelConfig comes from the module-level ``vllm.config`` import.
    model_config = ModelConfig(model_id)
    parallel_config = ParallelConfig(
        data_parallel_size=data_parallel_size,
        data_parallel_external_lb=external_lb,
    )
    vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)

    assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145


def test_eagle_draft_model_config():
    """Test that EagleDraft model config is correctly set."""
    target_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
    )
    speculative_config = SpeculativeConfig(
        model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
        num_speculative_tokens=1,
        target_model_config=target_model_config,
        target_parallel_config=ParallelConfig(),
    )
    draft_model_config = speculative_config.draft_model_config
    assert draft_model_config.hf_config.architectures == ["EagleLlamaForCausalLM"]
    assert draft_model_config.hf_text_config.architectures == ["EagleLlamaForCausalLM"]
    assert draft_model_config.hf_config.model_type == "eagle"
    assert draft_model_config.hf_text_config.model_type == "eagle"
    assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
    assert draft_model_config.architecture == "EagleLlamaForCausalLM"