# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
import os
from dataclasses import MISSING, Field, asdict, dataclass, field
from types import SimpleNamespace
from unittest.mock import patch

import pydantic
import pytest
from pydantic import ValidationError

import vllm.config.vllm as vllm_config_module
from vllm.compilation.backends import VllmBackend
from vllm.config import (
    CompilationConfig,
    KernelConfig,
    ModelConfig,
    ParallelConfig,
    PoolerConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
    update_config,
)
from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.kernel import IrOpPriorityConfig
from vllm.config.load import LoadConfig
from vllm.config.utils import get_field
from vllm.config.vllm import (
    OPTIMIZATION_LEVEL_TO_CONFIG,
    OptimizationLevel,
)
from vllm.platforms import current_platform


def test_compile_config_repr_succeeds():
    """``repr`` of a ``VllmConfig`` still works after
    ``VllmBackend.configure_post_pass()`` has run against it."""
    # setup: VllmBackend mutates the config object
    config = VllmConfig()
    backend = VllmBackend(config)
    backend.configure_post_pass()

    # test that repr(config) succeeds
    val = repr(config)
    assert "VllmConfig" in val
    assert "inductor_passes" in val


@pytest.mark.skip_global_cleanup
def test_with_hf_config_populates_missing_architectures_from_causal_lm_mapping(
    monkeypatch,
):
    """A missing ``architectures`` is filled in from the ``model_type``.

    ``replace`` in ``vllm.config.vllm`` is stubbed to build a
    ``SimpleNamespace`` from its keyword arguments, so no real configs are
    needed. The caller's ``hf_config`` object must not be modified.
    """

    def fake_replace(self, **kwargs):
        return SimpleNamespace(**kwargs)

    monkeypatch.setattr(vllm_config_module, "replace", fake_replace)

    model_config = SimpleNamespace(
        is_multimodal_model=False,
        hf_config=SimpleNamespace(),
        get_model_arch_config=lambda: "arch-config",
    )
    cfg = SimpleNamespace(model_config=model_config)
    hf_config = SimpleNamespace(model_type="mistral", architectures=None)

    updated = VllmConfig.with_hf_config(cfg, hf_config)

    assert updated.model_config.hf_config.architectures == ["MistralForCausalLM"]
    # The input object itself is left untouched.
    assert hf_config.architectures is None


@pytest.mark.skip_global_cleanup
def test_with_hf_config_preserves_explicit_architectures_override(monkeypatch):
    """An explicit ``architectures`` argument takes precedence over the
    ``model_type``-based default."""

    def fake_replace(self, **kwargs):
        return SimpleNamespace(**kwargs)

    monkeypatch.setattr(vllm_config_module, "replace", fake_replace)

    model_config = SimpleNamespace(
        is_multimodal_model=False,
        hf_config=SimpleNamespace(),
        get_model_arch_config=lambda: "arch-config",
    )
    cfg = SimpleNamespace(model_config=model_config)
    hf_config = SimpleNamespace(model_type="mistral", architectures=None)
    explicit_architectures = ["Ministral3ForCausalLM"]

    updated = VllmConfig.with_hf_config(
        cfg, hf_config, architectures=explicit_architectures
    )

    assert updated.model_config.hf_config.architectures == ["Ministral3ForCausalLM"]


@pytest.mark.skip_global_cleanup
def test_with_hf_config_leaves_unknown_model_type_without_architectures(
    monkeypatch,
):
    """An unrecognised ``model_type`` leaves ``architectures`` as ``None``."""
    # Stub ``replace`` so with_hf_config builds plain namespaces.
    monkeypatch.setattr(
        vllm_config_module,
        "replace",
        lambda self, **kwargs: SimpleNamespace(**kwargs),
    )
    cfg = SimpleNamespace(
        model_config=SimpleNamespace(
            is_multimodal_model=False,
            hf_config=SimpleNamespace(),
            get_model_arch_config=lambda: "arch-config",
        )
    )
    hf_config = SimpleNamespace(
        model_type="not_a_real_model",
        architectures=None,
    )

    updated = VllmConfig.with_hf_config(cfg, hf_config)

    assert updated.model_config.hf_config.architectures is None


def test_async_scheduling_with_pipeline_parallelism_is_allowed():
    """``VllmConfig`` accepts async scheduling together with pipeline
    parallelism (mp backend, two nodes) and keeps it enabled."""
    cfg = VllmConfig(
        scheduler_config=SchedulerConfig(
            max_model_len=8192,
            is_encoder_decoder=False,
            async_scheduling=True,
        ),
        parallel_config=ParallelConfig(
            pipeline_parallel_size=2,
            distributed_executor_backend="mp",
            nnodes=2,
        ),
    )
    assert cfg.scheduler_config.async_scheduling is True


@dataclass
class _TestConfigFields:
    """Dataclass fixture with a required field, a ``default_factory`` field
    and a plain-default field."""

    a: int
    b: dict = field(default_factory=dict)
    c: str = "default"


def test_get_field():
    """``get_field`` returns the dataclass ``Field`` with the default /
    default_factory declared on ``_TestConfigFields``."""
    b = get_field(_TestConfigFields, "b")
    assert isinstance(b, Field)
    assert b.default is MISSING
    assert b.default_factory is dict

    c = get_field(_TestConfigFields, "c")
    assert isinstance(c, Field)
    assert c.default == "default"
    assert c.default_factory is MISSING


@dataclass
class _TestNestedConfig:
    """Dataclass fixture whose single field is itself a dataclass."""

    a: _TestConfigFields = field(default_factory=lambda: _TestConfigFields(a=0))


def test_update_config():
    """``update_config`` sets fields and applies nested updates given as a
    dataclass or a dict. It rejects unknown fields and wrongly-typed nested
    values with ``AssertionError``."""
    # Simple update
    config1 = _TestConfigFields(a=0)
    new_config1 = update_config(config1, {"a": 42})
    assert new_config1.a == 42
    # Nonexistent field
    with pytest.raises(AssertionError):
        update_config(config1, {"nonexistent": 1})
    # Nested update with dataclass
    config2 = _TestNestedConfig()
    new_inner_config = _TestConfigFields(a=1, c="new_value")
    new_config2 = update_config(config2, {"a": new_inner_config})
    assert new_config2.a == new_inner_config
    # Nested update with dict
    config3 = _TestNestedConfig()
    new_config3 = update_config(config3, {"a": {"c": "new_value"}})
    assert new_config3.a.c == "new_value"
    # Nested update with invalid type
    with pytest.raises(AssertionError):
        update_config(config3, {"a": "new_value"})


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("distilbert/distilgpt2", "generate", "none"),
        ("intfloat/multilingual-e5-small", "pooling", "none"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none"),
        ("openai/whisper-small", "generate", "none"),
    ],
)
def test_auto_runner(model_id, expected_runner_type, expected_convert_type):
    """With ``runner="auto"``, ``ModelConfig`` picks the expected runner and
    convert types for each model."""
    config = ModelConfig(model_id, runner="auto")

    assert config.runner_type == expected_runner_type
    assert config.convert_type == expected_convert_type


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("distilbert/distilgpt2", "pooling", "embed"),
        ("intfloat/multilingual-e5-small", "pooling", "none"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none"),
        ("openai/whisper-small", "pooling", "embed"),
    ],
)
def test_pooling_runner(model_id, expected_runner_type, expected_convert_type):
    """With ``runner="pooling"``, every model uses the pooling runner.

    Only the convert type varies per model.
    """
    config = ModelConfig(model_id, runner="pooling")

    assert config.runner_type == expected_runner_type
    assert config.convert_type == expected_convert_type


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("Qwen/Qwen2.5-1.5B-Instruct", "draft", "none"),
    ],
)
def test_draft_runner(model_id, expected_runner_type, expected_convert_type):
    """With ``runner="draft"``, the draft runner is used without conversion."""
    config = ModelConfig(model_id, runner="draft")

    assert config.runner_type == expected_runner_type
    assert config.convert_type == expected_convert_type


# (model_id, expected max_model_len when sliding window is disabled)
MODEL_IDS_EXPECTED = [
    ("Qwen/Qwen1.5-7B", 32768),
    ("mistralai/Mistral-7B-v0.1", 4096),
    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """``disable_sliding_window=True`` yields the expected ``max_model_len``."""
    model_id, expected = model_id_expected
    model_config = ModelConfig(model_id, disable_sliding_window=True)
    assert model_config.max_model_len == expected


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_pooling_config():
    """``ModelConfig`` provides a pooler config for this sentence-transformers
    model: activation enabled, MEAN sequence pooling, ALL token pooling."""
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(model_id)

    assert model_config.pooler_config is not None
    assert model_config.pooler_config.use_activation
    assert model_config.pooler_config.seq_pooling_type == "MEAN"
    assert model_config.pooler_config.tok_pooling_type == "ALL"


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_pooling_config_from_args():
    """An explicit ``pooler_config`` passed to ``ModelConfig`` is kept as
    given (field-for-field equal)."""
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    pooler_config = PoolerConfig(seq_pooling_type="CLS", use_activation=False)
    model_config = ModelConfig(model_id, pooler_config=pooler_config)

    assert asdict(model_config.pooler_config) == asdict(pooler_config)


@pytest.mark.parametrize(
    ("model_id", "default_pooling_type", "pooling_type"),
    [
        ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"),  # LLM
        ("intfloat/e5-small", "CLS", "MEAN"),  # BertModel
    ],
)
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
    """Check the model's default sequence pooling type and the effective
    ``pooler_config`` value.

    The two can differ, e.g. e5-small defaults to CLS but ends up with MEAN.
    """
    model_config = ModelConfig(model_id)
    assert model_config._model_info.default_seq_pooling_type == default_pooling_type
    assert model_config.pooler_config.seq_pooling_type == pooling_type


@pytest.mark.parametrize(
    ("model_id", "default_pooling_type", "pooling_type"),
    [
        ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"),  # reward
        ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"),  # step reward
    ],
)
def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
    """Check the model's default token pooling type and the effective
    ``pooler_config`` value."""
    model_config = ModelConfig(model_id)
    assert model_config._model_info.default_tok_pooling_type == default_pooling_type
    assert model_config.pooler_config.tok_pooling_type == pooling_type


@pytest.mark.parametrize(
    ("model_id", "expected_is_moe_model"),
    [
        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", False),
        ("RedHatAI/Llama-3.2-1B-FP8", False),
        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", False),
        ("RedHatAI/gpt-oss-20b", True),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
        ("RedHatAI/Llama-4-Scout-17B-16E-Instruct", True),
        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", True),
    ],
)
def test_moe_model_detection(model_id, expected_is_moe_model):
    """``ModelConfig.is_moe`` classifies each model as MoE or dense."""
    model_config = ModelConfig(model_id)
    # Check that is_moe matches the expected MoE classification for the model.
    assert model_config.is_moe == expected_is_moe_model


@pytest.mark.parametrize(
    ("model_id", "quantized"),
    [
        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", True),
        ("RedHatAI/Llama-3.2-1B-FP8", True),
        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", True),
        ("RedHatAI/gpt-oss-20b", True),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", False),
    ],
)
def test_is_quantized(model_id, quantized):
    """``ModelConfig.is_quantized`` reports each checkpoint's quantization."""
    model_config = ModelConfig(model_id)
    # Check that is_quantized matches the expected status for the checkpoint.
    assert model_config.is_quantized == quantized


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_bert_tokenization_sentence_transformer_config():
    """``_get_encoder_config`` exposes ``max_seq_length`` and
    ``do_lower_case`` for a BGE sentence-transformers model."""
    model_id = "BAAI/bge-base-en-v1.5"
    bge_model_config = ModelConfig(model_id)

    bert_bge_model_config = bge_model_config._get_encoder_config()

    assert bert_bge_model_config["max_seq_length"] == 512
    assert bert_bge_model_config["do_lower_case"]


def test_rope_customization():
    """Native ``rope_parameters`` and ``hf_overrides`` of them are reflected
    in ``hf_config`` and in the resulting ``max_model_len``."""
    TEST_ROPE_PARAMETERS = {
        "rope_theta": 16_000_000.0,
        "rope_type": "dynamic",
        "factor": 2.0,
    }
    LLAMA_ROPE_PARAMETERS = {"rope_theta": 500000.0, "rope_type": "default"}
    LONGCHAT_ROPE_PARAMETERS = {"rope_type": "linear", "factor": 8.0}

    llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
    assert (
        getattr(llama_model_config.hf_config, "rope_parameters", None)
        == LLAMA_ROPE_PARAMETERS
    )
    assert llama_model_config.max_model_len == 8192

    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
    )
    assert (
        getattr(llama_model_config.hf_config, "rope_parameters", None)
        == TEST_ROPE_PARAMETERS
    )
    assert llama_model_config.max_model_len == 16384

    longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
    # Every LONGCHAT_ROPE_PARAMETERS entry must be present in the native
    # rope_parameters (additional keys are allowed).
    assert all(
        longchat_model_config.hf_config.rope_parameters.get(key) == value
        for key, value in LONGCHAT_ROPE_PARAMETERS.items()
    )
    assert longchat_model_config.max_model_len == 16384

    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        hf_overrides={
            "rope_parameters": TEST_ROPE_PARAMETERS,
        },
    )
    assert (
        getattr(longchat_model_config.hf_config, "rope_parameters", None)
        == TEST_ROPE_PARAMETERS
    )
    assert longchat_model_config.max_model_len == 4096


def test_nested_hf_overrides():
    """Test that nested hf_overrides work correctly.

    Overrides given as nested dicts are applied to the matching sub-configs
    (``text_config``, ``vision_config``) of the HF config.
    """
    # Test with a model that has text_config
    model_config = ModelConfig(
        "Qwen/Qwen2-VL-2B-Instruct",
        hf_overrides={
            "text_config": {
                "hidden_size": 1024,
            },
        },
    )
    assert model_config.hf_config.text_config.hidden_size == 1024

    # Test with deeply nested overrides
    model_config = ModelConfig(
        "Qwen/Qwen2-VL-2B-Instruct",
        hf_overrides={
            "text_config": {
                "hidden_size": 2048,
                "num_attention_heads": 16,
            },
            "vision_config": {
                "hidden_size": 512,
            },
        },
    )
    assert model_config.hf_config.text_config.hidden_size == 2048
    assert model_config.hf_config.text_config.num_attention_heads == 16
    assert model_config.hf_config.vision_config.hidden_size == 512


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm."
)
@pytest.mark.parametrize(
    ("model_id", "is_encoder_decoder"),
    [
        ("facebook/opt-125m", False),
        ("openai/whisper-tiny", True),
        ("meta-llama/Llama-3.2-1B-Instruct", False),
    ],
)
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """``is_encoder_decoder`` is True only for the encoder-decoder model."""
    config = ModelConfig(model_id)

    assert config.is_encoder_decoder == is_encoder_decoder


@pytest.mark.parametrize(
    ("model_id", "uses_mrope"),
    [
        ("facebook/opt-125m", False),
        ("Qwen/Qwen2-VL-2B-Instruct", True),
    ],
)
def test_uses_mrope(model_id, uses_mrope):
    """``uses_mrope`` is True for Qwen2-VL and False for OPT."""
    config = ModelConfig(model_id)

    assert config.uses_mrope == uses_mrope


def test_generation_config_loading():
    """``generation_config`` and ``override_generation_config`` determine
    what ``get_diff_sampling_param`` returns."""
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"

    # When generation_config is set to "vllm", the model's default generation
    # config is not loaded.
    model_config = ModelConfig(model_id, generation_config="vllm")
    assert model_config.get_diff_sampling_param() == {}

    # When generation_config is set to "auto", the model's default generation
    # config is loaded.
    model_config = ModelConfig(model_id, generation_config="auto")

    correct_generation_config = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }

    assert model_config.get_diff_sampling_param() == correct_generation_config

    # The generation config can be overridden by the user.
    override_generation_config = {"temperature": 0.5, "top_k": 5}

    model_config = ModelConfig(
        model_id,
        generation_config="auto",
        override_generation_config=override_generation_config,
    )

    # User overrides replace the loaded defaults key by key.
    override_result = correct_generation_config.copy()
    override_result.update(override_generation_config)

    assert model_config.get_diff_sampling_param() == override_result

    # When generation_config is set to "vllm" and override_generation_config
    # is set, the override_generation_config should be used directly.
    model_config = ModelConfig(
        model_id,
        generation_config="vllm",
        override_generation_config=override_generation_config,
    )

    assert model_config.get_diff_sampling_param() == override_generation_config


@pytest.mark.parametrize(
    "pt_load_map_location",
    [
        "cuda",
        {"": "cuda"},
    ],
)
def test_load_config_pt_load_map_location(pt_load_map_location):
    """``pt_load_map_location`` accepts both a device string and a mapping.

    ``VllmConfig`` keeps the value unchanged.
    """
    load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
    config = VllmConfig(load_config=load_config)

    assert config.load_config.pt_load_map_location == pt_load_map_location


@pytest.mark.parametrize(
    ("model_id", "max_model_len", "expected_max_len", "should_raise"),
    [
        ("BAAI/bge-reranker-base", None, 512, False),
        ("BAAI/bge-reranker-base", 256, 256, False),
        ("BAAI/bge-reranker-base", 513, 512, True),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", None, 131072, False),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 131073, 131072, True),
    ],
)
def test_get_and_verify_max_len(
    model_id, max_model_len, expected_max_len, should_raise
):
    """Test get_and_verify_max_len with different configurations.

    Per the cases above:
    - ``None`` yields the model's maximum.
    - A smaller value is kept.
    - A value above the maximum raises ``ValueError``.
    """
    model_config = ModelConfig(model_id)

    if should_raise:
        with pytest.raises(ValueError):
            model_config.get_and_verify_max_len(max_model_len)
    else:
        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
        assert actual_max_len == expected_max_len


class MockConfig:
    """Simple mock object for testing maybe_pull_model_tokenizer_for_runai

    Holds the given ``model`` and ``tokenizer`` paths and a ``model_weights``
    attribute initialised to ``None``.
    """

    def __init__(self, model: str, tokenizer: str):
        self.model = model
        self.tokenizer = tokenizer
        self.model_weights = None


@pytest.mark.parametrize(
    "s3_url",
    [
        "s3://example-bucket-1/model/",
        "s3://example-bucket-2/model/",
    ],
)
@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files")
def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url):
    """Test that S3 URLs create deterministic local directories for model and
    tokenizer."""
    # Mock pull_files to avoid actually downloading files during tests
    mock_pull_files.return_value = None

    # Create first mock and run the method
    config1 = MockConfig(model=s3_url, tokenizer=s3_url)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url, s3_url)

    # Check that model and tokenizer point to existing directories
    assert os.path.exists(config1.model), (
        f"Model directory does not exist: {config1.model}"
    )
    assert os.path.isdir(config1.model), (
        f"Model path is not a directory: {config1.model}"
    )
    assert os.path.exists(config1.tokenizer), (
        f"Tokenizer directory does not exist: {config1.tokenizer}"
    )
    assert os.path.isdir(config1.tokenizer), (
        f"Tokenizer path is not a directory: {config1.tokenizer}"
    )

    # Verify that the paths are different from the original S3 URL
    assert config1.model != s3_url, "Model path should be converted to local directory"
    assert config1.tokenizer != s3_url, (
        "Tokenizer path should be converted to local directory"
    )

    # Store the original paths
    created_model_dir = config1.model
    created_tokenizer_dir = config1.tokenizer

    # Create a new mock and run the method with the same S3 URL
    config2 = MockConfig(model=s3_url, tokenizer=s3_url)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url, s3_url)

    # Check that the new directories exist
    assert os.path.exists(config2.model), (
        f"Model directory does not exist: {config2.model}"
    )
    assert os.path.isdir(config2.model), (
        f"Model path is not a directory: {config2.model}"
    )
    assert os.path.exists(config2.tokenizer), (
        f"Tokenizer directory does not exist: {config2.tokenizer}"
    )
    assert os.path.isdir(config2.tokenizer), (
        f"Tokenizer path is not a directory: {config2.tokenizer}"
    )

    # Verify that the paths are deterministic (same as before)
    assert config2.model == created_model_dir, (
        f"Model paths are not deterministic. "
        f"Original: {created_model_dir}, New: {config2.model}"
    )
    assert config2.tokenizer == created_tokenizer_dir, (
        f"Tokenizer paths are not deterministic. "
        f"Original: {created_tokenizer_dir}, New: {config2.tokenizer}"
    )


@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files")
def test_s3_url_different_models_create_different_directories(mock_pull_files):
    """Test that different S3 URLs create different local directories."""
    # Mock pull_files to avoid actually downloading files during tests
    mock_pull_files.return_value = None

    s3_url1 = "s3://example-bucket-1/model/"
    s3_url2 = "s3://example-bucket-2/model/"

    # Create mocks with different S3 URLs and run the method
    config1 = MockConfig(model=s3_url1, tokenizer=s3_url1)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url1, s3_url1)

    config2 = MockConfig(model=s3_url2, tokenizer=s3_url2)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url2, s3_url2)

    # Verify that different URLs produce different directories
    assert config1.model != config2.model, (
        f"Different S3 URLs should create different model directories. "
        f"URL1 model: {config1.model}, URL2 model: {config2.model}"
    )
    assert config1.tokenizer != config2.tokenizer, (
        f"Different S3 URLs should create different tokenizer directories. "
        f"URL1 tokenizer: {config1.tokenizer}, "
        f"URL2 tokenizer: {config2.tokenizer}"
    )

    # Verify that both sets of directories exist (isdir is False for
    # nonexistent paths, so it covers the existence check as well).
    assert os.path.isdir(config1.model)
    assert os.path.isdir(config1.tokenizer)
    assert os.path.isdir(config2.model)
    assert os.path.isdir(config2.tokenizer)


@pytest.mark.parametrize(
    ("model_id", "expected_attn_type", "expected_result", "reason"),
    [
        # pooling models
        (
            "jason9693/Qwen2.5-1.5B-apeach",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen3-Embedding-0.6B",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen2.5-Math-PRM-7B",
            "decoder",
            False,
            "Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.",  # noqa: E501
        ),
        (
            "internlm/internlm2-1_8b-reward",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "BAAI/bge-base-en",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "boltuix/NeuroBERT-NER",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "papluca/xlm-roberta-base-language-detection",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "intfloat/e5-small",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        # multimodal models
        (
            "openai/clip-vit-base-patch32",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "google/siglip-base-patch16-224",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        # generate models
        (
            "Qwen/Qwen3-0.6B",
            "decoder",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "hybrid",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "ibm-granite/granite-4.0-h-small",
            "hybrid",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "state-spaces/mamba-130m-hf",
            "attention_free",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        # encoder_decoder models
        (
            "openai/whisper-small",
            "encoder_decoder",
            False,
            "Encoder decoder models do not support chunked prefill.",  # noqa: E501
        ),
    ],
)
def test_is_chunked_prefill_supported(
    model_id: str,
    expected_attn_type: str,
    expected_result: bool,
    reason: str,
    caplog_vllm,
):
    """Check the attention type and chunked-prefill support for each model.

    Also checks that the expected reason is logged. Logs are captured at
    DEBUG level on the ``vllm`` logger.
    """
    model_config = ModelConfig(model_id, trust_remote_code=True)
    assert model_config.attn_type == expected_attn_type
    with caplog_vllm.at_level(level=logging.DEBUG, logger="vllm"):
        assert model_config.is_chunked_prefill_supported == expected_result
    assert reason in caplog_vllm.text


@pytest.mark.parametrize(
    ("model_id", "expected_attn_type", "expected_result", "reason"),
    [
        # pooling models
        (
            "jason9693/Qwen2.5-1.5B-apeach",
            "decoder",
            True,
783
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
784
785
786
787
788
        ),
        (
            "Qwen/Qwen3-Embedding-0.6B",
            "decoder",
            True,
789
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
790
791
792
793
794
        ),
        (
            "Qwen/Qwen2.5-Math-PRM-7B",
            "decoder",
            False,
795
            "Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.",  # noqa: E501
796
797
798
799
        ),
        (
            "internlm/internlm2-1_8b-reward",
            "decoder",
800
            True,
801
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
802
803
804
805
806
        ),
        (
            "BAAI/bge-base-en",
            "encoder_only",
            False,
807
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
808
809
810
811
812
        ),
        (
            "boltuix/NeuroBERT-NER",
            "encoder_only",
            False,
813
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
814
815
816
817
818
        ),
        (
            "papluca/xlm-roberta-base-language-detection",
            "encoder_only",
            False,
819
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
820
821
822
823
824
        ),
        (
            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
            "encoder_only",
            False,
825
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
826
827
828
829
830
        ),
        (
            "intfloat/e5-small",
            "encoder_only",
            False,
831
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
832
833
834
835
836
837
        ),
        # multimodal models
        (
            "openai/clip-vit-base-patch32",
            "decoder",
            True,
838
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
839
840
841
842
843
        ),
        (
            "google/siglip-base-patch16-224",
            "encoder_only",
            False,
844
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
845
846
847
848
849
850
        ),
        # generate models
        (
            "Qwen/Qwen3-0.6B",
            "decoder",
            True,
851
            "Generative models support prefix caching.",  # noqa: E501
852
853
854
855
856
        ),
        (
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "hybrid",
            False,
857
            "Hybrid models do not support prefix caching since the feature is still experimental.",  # noqa: E501
858
859
860
861
862
        ),
        (
            "ibm-granite/granite-4.0-h-small",
            "hybrid",
            False,
863
            "Hybrid models do not support prefix caching since the feature is still experimental.",  # noqa: E501
864
865
866
867
868
        ),
        (
            "state-spaces/mamba-130m-hf",
            "attention_free",
            False,
869
            "Attention free models do not support prefix caching since the feature is still experimental.",  # noqa: E501
870
871
872
873
874
875
        ),
        # encoder_decoder models
        (
            "openai/whisper-small",
            "encoder_decoder",
            False,
876
            "Encoder decoder models do not support prefix caching.",  # noqa: E501
877
878
879
880
881
882
883
884
885
886
887
888
        ),
    ],
)
def test_is_prefix_caching_supported(
    model_id: str,
    expected_attn_type: str,
    expected_result: bool,
    reason: str,
    caplog_vllm,
):
    """Check attention-type detection and the prefix-caching support decision.

    ``is_prefix_caching_supported`` is evaluated while DEBUG records from the
    ``vllm`` logger are captured; the captured text must contain the expected
    explanation for the decision.
    """
    config = ModelConfig(model_id, trust_remote_code=True)
    assert config.attn_type == expected_attn_type

    with caplog_vllm.at_level(level=logging.DEBUG, logger="vllm"):
        supported = config.is_prefix_caching_supported
        assert supported == expected_result
    assert reason in caplog_vllm.text


@pytest.mark.parametrize(
    ("backend", "custom_ops", "expected"),
    [
        ("eager", [], True),
        ("eager", ["+fused_layernorm"], True),
        ("eager", ["all", "-fused_layernorm"], False),
        ("inductor", [], False),
        ("inductor", ["none", "+fused_layernorm"], True),
        ("inductor", ["none", "-fused_layernorm"], False),
    ],
)
def test_is_custom_op_enabled(backend: str, custom_ops: list[str], expected: bool):
    """Check ``is_custom_op_enabled`` for each backend / custom_ops combination.

    Per the cases above, an empty ``custom_ops`` list enables the op under
    the eager backend and disables it under inductor. Explicit ``+op`` /
    ``-op`` entries (after ``all`` / ``none``) decide the final result.
    """
    compilation_config = CompilationConfig(backend=backend, custom_ops=custom_ops)
    vllm_config = VllmConfig(compilation_config=compilation_config)
    enabled = vllm_config.compilation_config.is_custom_op_enabled("fused_layernorm")
    assert enabled is expected


def test_vllm_config_defaults_are_none():
    """Check that level-controlled compilation fields start out unset.

    For every optimization level, each field that the level's default table
    lists must be None on a freshly constructed CompilationConfig, before any
    level defaults are applied. This covers the nested ``pass_config``
    entries as well.
    """
    for level in OptimizationLevel:
        # object.__new__ skips VllmConfig.__init__, so nothing resolves defaults.
        config = object.__new__(VllmConfig)
        config.compilation_config = CompilationConfig()
        config.optimization_level = level
        config.model_config = None

        level_defaults = OPTIMIZATION_LEVEL_TO_CONFIG[level]["compilation_config"]
        pass_config = config.compilation_config.pass_config

        for name in level_defaults["pass_config"]:
            assert getattr(pass_config, name) is None

        for name in level_defaults:
            if name == "pass_config":
                continue
            assert getattr(config.compilation_config, name) is None


@pytest.mark.parametrize(
    ("model_id", "compilation_config", "optimization_level"),
    [
        (
            None,
            CompilationConfig(backend="eager", custom_ops=["+quant_fp8"]),
            OptimizationLevel.O0,
        ),
        (None, CompilationConfig(), OptimizationLevel.O0),
        (None, CompilationConfig(), OptimizationLevel.O1),
        (None, CompilationConfig(), OptimizationLevel.O2),
        (None, CompilationConfig(), OptimizationLevel.O3),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(backend="inductor", custom_ops=["+quant_fp8"]),
            OptimizationLevel.O2,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O0,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O1,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O2,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O3,
        ),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O0),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O1),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O2),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O3),
    ],
)
def test_vllm_config_defaults(model_id, compilation_config, optimization_level):
    """Check that optimization-level defaults land on the resolved VllmConfig.

    Builds a VllmConfig at the requested level, optionally with a real model.
    Every compilation_config / pass_config field listed in
    OPTIMIZATION_LEVEL_TO_CONFIG is then compared with the resolved value.
    Callable defaults are evaluated against the constructed config.
    """
    config_kwargs = {
        "compilation_config": compilation_config,
        "optimization_level": optimization_level,
    }
    # Only pass a model config when the case names a model.
    if model_id is not None:
        config_kwargs["model_config"] = ModelConfig(model_id)
    vllm_config = VllmConfig(**config_kwargs)

    compilation_defaults = OPTIMIZATION_LEVEL_TO_CONFIG[optimization_level][
        "compilation_config"
    ]
    resolved = vllm_config.compilation_config

    def _resolve(default):
        # Level defaults are either plain values or callables of the VllmConfig.
        return default(vllm_config) if callable(default) else default

    # Nested pass_config defaults.
    for name, default in compilation_defaults["pass_config"].items():
        actual = getattr(resolved.pass_config, name)
        expected = _resolve(default)
        assert actual == expected, (
            f"pass_config.{name}: expected {expected}, got {actual}"
        )

    # Remaining top-level compilation_config defaults.
    for name, default in compilation_defaults.items():
        if name == "pass_config":
            continue
        actual = getattr(resolved, name)
        expected = _resolve(default)
        # Without static graph support, __post_init__ forces cudagraph_mode to
        # NONE regardless of the level default.
        if name == "cudagraph_mode" and (
            not current_platform.support_static_graph_mode()
        ):
            expected = CUDAGraphMode.NONE
        assert actual == expected, (
            f"compilation_config.{name}: expected {expected}, got {actual}"
        )


def test_vllm_config_callable_defaults():
    """Exercise the kind of predicates used as callable config defaults.

    Such predicates inspect VllmConfig properties, such as whether a model is
    present, quantized, or MoE, to decide whether to enable an optimization.
    Here they are evaluated directly against configs built with no model, a
    quantized model, and an MoE model.
    """
    config_no_model = VllmConfig(optimization_level=OptimizationLevel.O2)

    def has_model(cfg):
        return cfg.model_config is not None

    assert has_model(config_no_model) is False

    # Quantized checkpoint.
    config_quantized = VllmConfig(
        model_config=ModelConfig("RedHatAI/Llama-3.2-1B-FP8"),
        optimization_level=OptimizationLevel.O2,
    )

    def enable_if_quantized(cfg):
        return cfg.model_config is not None and cfg.model_config.is_quantized

    assert enable_if_quantized(config_quantized) is True
    assert enable_if_quantized(config_no_model) is False

    # MoE checkpoint.
    config_moe = VllmConfig(
        model_config=ModelConfig("deepseek-ai/DeepSeek-V2-Lite"),
        optimization_level=OptimizationLevel.O2,
    )

    def enable_if_sequential(cfg):
        return cfg.model_config is not None and not cfg.model_config.is_moe

    assert enable_if_sequential(config_moe) is False
    assert enable_if_sequential(config_quantized) is True


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Explicit overrides may be force-overwritten without static graph support.",
)
def test_vllm_config_explicit_overrides():
    """Test that explicit property overrides work correctly with callable defaults.

    When users explicitly set configuration properties, those values
    take precedence over callable defaults, across different models and
    optimization levels. Fields that are not overridden keep their
    level defaults.
    """
    from vllm.config.compilation import PassConfig

    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
    regular_model = ModelConfig("Qwen/Qwen1.5-7B")

    # Explicit compilation mode override on O0 (where default is NONE)
    compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    config = VllmConfig(
        optimization_level=OptimizationLevel.O0,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
    # cudagraph_mode was not overridden, so it keeps the O0 value.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE

    # Explicit pass config flags to override defaults
    pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True)
    compilation_config = CompilationConfig(pass_config=pass_config)
    config = VllmConfig(
        optimization_level=OptimizationLevel.O0,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.pass_config.eliminate_noops is True
    assert config.compilation_config.pass_config.fuse_attn_quant is True

    # Explicit cudagraph mode override on quantized model at O2
    pass_config = PassConfig(enable_qk_norm_rope_fusion=True)
    compilation_config = CompilationConfig(
        cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
    )
    config = VllmConfig(
        model_config=quantized_model,
        optimization_level=OptimizationLevel.O2,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert config.compilation_config.pass_config.enable_qk_norm_rope_fusion is True
    # Mode should still use default for O2
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE

    # Different optimization levels with same model
    config_o0 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O0
    )
    config_o2 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O2
    )
    assert config_o0.compilation_config.mode == CompilationMode.NONE
    assert config_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_o0.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert (
        config_o2.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    )

    # Same optimization level across different model types
    config_moe_o2 = VllmConfig(
        model_config=moe_model, optimization_level=OptimizationLevel.O2
    )
    config_regular_o2 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O2
    )
    config_quantized_o2 = VllmConfig(
        model_config=quantized_model, optimization_level=OptimizationLevel.O2
    )
    # All should have same base compilation settings at O2
    assert config_moe_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_regular_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_quantized_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert (
        config_moe_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )
    assert (
        config_regular_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )
    # Previously unchecked: the quantized model must match the other two.
    assert (
        config_quantized_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )

    # Override one field but not others
    pass_config = PassConfig(eliminate_noops=False)
    compilation_config = CompilationConfig(pass_config=pass_config)
    config = VllmConfig(
        model_config=regular_model,
        optimization_level=OptimizationLevel.O2,
        compilation_config=compilation_config,
    )
    # Explicit override should be respected
    assert config.compilation_config.pass_config.eliminate_noops is False
    # Other fields should still use defaults
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE


def test_fusion_pass_op_priority():
    """Check that custom-op enablement and IR op priority decide whether the
    default rms_norm+quant fusion (``pass_config.fuse_norm_quant``) is on."""

    # Default config (O2): the fusion stays off.
    default_cfg = VllmConfig()
    assert not default_cfg.compilation_config.pass_config.fuse_norm_quant

    # O1 with rms_norm explicitly enabled as a custom op: the fusion is on.
    rms_custom_cfg = VllmConfig(
        optimization_level=OptimizationLevel.O1,
        compilation_config=CompilationConfig(custom_ops=["+rms_norm"]),
    )
    assert rms_custom_cfg.compilation_config.pass_config.fuse_norm_quant

    # rms_norm routed to the custom kernel via IR op priority. vLLM IR only
    # covers the non-residual rms_norm for now; this is expected to change.
    ir_cfg = VllmConfig(
        kernel_config=KernelConfig(
            ir_op_priority=IrOpPriorityConfig(rms_norm=["vllm_c"])
        )
    )
    assert ir_cfg.compilation_config.pass_config.fuse_norm_quant

    # A block-FP8 checkpoint enables quant_fp8 automatically, and the fusion
    # is enabled as well.
    fp8_cfg = VllmConfig(model_config=ModelConfig("Qwen/Qwen3-4B-FP8"))
    assert "+quant_fp8" in fp8_cfg.compilation_config.custom_ops
    assert fp8_cfg.compilation_config.pass_config.fuse_norm_quant


def test_scheduler_config_init():
    """Check how SchedulerConfig handles its InitVars.

    - Omitting the required positional InitVars fails validation. InitVars
      cannot carry defaults, because they would then become attributes.
    - An InitVar is consumed during initialization and is not stored as an
      attribute on the resulting instance.
    """
    with pytest.raises(ValidationError):
        # Positional InitVars missing
        SchedulerConfig()

    with pytest.raises(AttributeError):
        # Only the attribute lookup matters. Binding it to `_` avoids printing
        # and keeps linters (B018) from flagging a useless expression.
        _ = SchedulerConfig.default_factory().max_model_len


@pytest.mark.parametrize(
    (
        "model_id",
        "data_parallel_size",
        "external_lb",
        "expected_needs_coordinator",
    ),
    [
        # Non-MoE model with DP=1 should not need coordinator
        ("facebook/opt-125m", 1, False, False),
        # Non-MoE model with DP>1 internal LB should need coordinator
        ("facebook/opt-125m", 2, False, True),
        # Non-MoE model with DP>1 external LB should not need coordinator
        ("facebook/opt-125m", 2, True, False),
        # MoE model with DP=1 should not need coordinator
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 1, False, False),
        # MoE model with DP>1 internal LB should need both coordinator
        # and wave coordination
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, False, True),
        # MoE model with DP>1 external LB needs coordinator for wave coordination
        # (wave coordination runs in coordinator process)
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, True, True),
    ],
)
def test_needs_dp_coordination(
    model_id,
    data_parallel_size,
    external_lb,
    expected_needs_coordinator,
):
    """Test whether VllmConfig reports that a DP coordinator is needed.

    Only ``needs_dp_coordinator`` is asserted. Per the cases above, MoE
    models with DP > 1 need the coordinator even with an external load
    balancer, because wave coordination runs in the coordinator process.
    """
    model_config = ModelConfig(model_id)
    parallel_config = ParallelConfig(
        data_parallel_size=data_parallel_size,
        data_parallel_external_lb=external_lb,
    )
    vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)

    assert vllm_config.needs_dp_coordinator == expected_needs_coordinator


def test_renderer_num_workers_with_mm_cache():
    """Reject renderer_num_workers > 1 while the multimodal processor cache is
    enabled, since neither cache type is thread-safe."""
    mm_model = "Qwen/Qwen2-VL-2B-Instruct"

    # Multi-worker setups that keep the cache enabled must be rejected.
    rejected = [
        {"renderer_num_workers": 4},  # default cache_gb=4
        {"renderer_num_workers": 2, "mm_processor_cache_gb": 1.0},  # explicit size
    ]
    for kwargs in rejected:
        with pytest.raises(ValueError, match="renderer-num-workers"):
            ModelConfig(mm_model, **kwargs)

    # Accepted: cache disabled with many workers, or a single worker with
    # the default cache.
    accepted = [
        ({"renderer_num_workers": 4, "mm_processor_cache_gb": 0}, 4),
        ({"renderer_num_workers": 1}, 1),
    ]
    for kwargs, expected_workers in accepted:
        config = ModelConfig(mm_model, **kwargs)
        assert config.renderer_num_workers == expected_workers


def test_eagle_draft_model_config():
    """Check the draft ModelConfig that SpeculativeConfig builds for an EAGLE
    checkpoint. Both hf_config and hf_text_config, as well as the resolved
    architecture fields, must report the EAGLE Llama variant."""
    target_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
    )
    spec_config = SpeculativeConfig(
        model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
        num_speculative_tokens=1,
        target_model_config=target_config,
        target_parallel_config=ParallelConfig(),
    )
    draft = spec_config.draft_model_config
    eagle_arch = "EagleLlamaForCausalLM"

    for hf_cfg in (draft.hf_config, draft.hf_text_config):
        assert hf_cfg.architectures == [eagle_arch]
    for hf_cfg in (draft.hf_config, draft.hf_text_config):
        assert hf_cfg.model_type == "eagle"

    assert draft.architectures == [eagle_arch]
    assert draft.architecture == eagle_arch


def test_ir_op_priority_default():
    """Test ``IrOpPriorityConfig.with_default``.

    The shared default priority list is applied to ops that are not given
    explicitly (checked here for rms_norm). An op passed as a keyword keeps
    its own list instead.
    """
    # IrOpPriorityConfig comes from the module-level import; no local import.
    # Assert default is applied to ops
    priority_config = IrOpPriorityConfig.with_default(["vllm_c", "native"])
    assert priority_config.rms_norm == ["vllm_c", "native"]

    # Assert single ops override the default
    assert IrOpPriorityConfig.with_default(
        ["vllm_c", "native"], rms_norm=["oink", "native"]
    ) == IrOpPriorityConfig(rms_norm=["oink", "native"])


def test_ir_op_priority_str():
    """Test that op priorities given as strings are parsed into lists.

    The config accepts a single op name, or a comma-delimited string with the
    whitespace around each name stripped. A list that contains non-string
    entries is rejected with a pydantic ValidationError.
    """
    from vllm.config.kernel import IrOpPriorityConfig

    # A single name becomes a one-element list.
    priority_config = IrOpPriorityConfig(rms_norm="vllm_c")
    assert priority_config.rms_norm == ["vllm_c"]

    # Comma-delimited names are split, preserving order.
    priority_config = IrOpPriorityConfig(rms_norm="vllm_c,native")
    assert priority_config.rms_norm == ["vllm_c", "native"]

    # Whitespace around the whole string and around each name is stripped.
    priority_config = IrOpPriorityConfig(rms_norm=" native, vllm_c ")
    assert priority_config.rms_norm == ["native", "vllm_c"]

    with pytest.raises(pydantic.ValidationError):
        # must be list of only strings
        priority_config = IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"])