# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
import os
from dataclasses import MISSING, Field, asdict, dataclass, field
from types import SimpleNamespace
from unittest.mock import patch

import pydantic
import pytest
from pydantic import ValidationError

import vllm.config.vllm as vllm_config_module
from vllm.compilation.backends import VllmBackend
from vllm.config import (
    CompilationConfig,
    KernelConfig,
    ModelConfig,
    ParallelConfig,
    PoolerConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
    update_config,
)
from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.kernel import IrOpPriorityConfig
from vllm.config.load import LoadConfig
from vllm.config.utils import get_field
from vllm.config.vllm import (
    OPTIMIZATION_LEVEL_TO_CONFIG,
    OptimizationLevel,
)
from vllm.platforms import current_platform


# Device type reported by the current platform; used to parametrize the
# pt_load_map_location test.
DEVICE_TYPE = current_platform.device_type


def test_compile_config_repr_succeeds():
    """repr() of a VllmConfig must still succeed after VllmBackend touched it."""
    # setup: VllmBackend mutates the config object
    config = VllmConfig()
    backend = VllmBackend(config)
    backend.configure_post_pass()

    # test that repr(config) succeeds
    val = repr(config)
    assert "VllmConfig" in val
    assert "inductor_passes" in val


@pytest.mark.skip_global_cleanup
def test_with_hf_config_populates_missing_architectures_from_causal_lm_mapping(
    monkeypatch,
):
    """A known model_type with no architectures gets its CausalLM class filled in.

    The module-level ``replace`` is swapped for a SimpleNamespace factory so the
    method can run against lightweight stand-in objects.
    """

    def fake_replace(self, **kwargs):
        return SimpleNamespace(**kwargs)

    monkeypatch.setattr(vllm_config_module, "replace", fake_replace)

    stub_model_config = SimpleNamespace(
        is_multimodal_model=False,
        hf_config=SimpleNamespace(),
        get_model_arch_config=lambda: "arch-config",
    )
    cfg = SimpleNamespace(model_config=stub_model_config)
    hf_config = SimpleNamespace(model_type="mistral", architectures=None)

    updated = VllmConfig.with_hf_config(cfg, hf_config)
    new_architectures = updated.model_config.hf_config.architectures

    assert new_architectures == ["MistralForCausalLM"]
    # The caller's hf_config object is left untouched.
    assert hf_config.architectures is None


@pytest.mark.skip_global_cleanup
def test_with_hf_config_preserves_explicit_architectures_override(monkeypatch):
    """An explicit ``architectures`` argument is used instead of the default."""

    def fake_replace(self, **kwargs):
        return SimpleNamespace(**kwargs)

    monkeypatch.setattr(vllm_config_module, "replace", fake_replace)

    stub_model_config = SimpleNamespace(
        is_multimodal_model=False,
        hf_config=SimpleNamespace(),
        get_model_arch_config=lambda: "arch-config",
    )
    cfg = SimpleNamespace(model_config=stub_model_config)
    hf_config = SimpleNamespace(model_type="mistral", architectures=None)
    explicit_architectures = ["Ministral3ForCausalLM"]

    updated = VllmConfig.with_hf_config(
        cfg,
        hf_config,
        architectures=explicit_architectures,
    )

    assert updated.model_config.hf_config.architectures == ["Ministral3ForCausalLM"]


@pytest.mark.skip_global_cleanup
def test_with_hf_config_leaves_unknown_model_type_without_architectures(
    monkeypatch,
):
    """An unrecognised model_type leaves ``architectures`` unset (None)."""
    # Swap the module-level ``replace`` for a SimpleNamespace factory so
    # with_hf_config can run against lightweight stand-in objects.
    monkeypatch.setattr(
        vllm_config_module,
        "replace",
        lambda self, **kwargs: SimpleNamespace(**kwargs),
    )
    cfg = SimpleNamespace(
        model_config=SimpleNamespace(
            is_multimodal_model=False,
            hf_config=SimpleNamespace(),
            get_model_arch_config=lambda: "arch-config",
        )
    )
    hf_config = SimpleNamespace(
        model_type="not_a_real_model",
        architectures=None,
    )

    updated = VllmConfig.with_hf_config(cfg, hf_config)

    assert updated.model_config.hf_config.architectures is None


def test_async_scheduling_with_pipeline_parallelism_is_allowed():
    """Async scheduling plus pipeline parallelism must not raise or be disabled."""
    cfg = VllmConfig(
        scheduler_config=SchedulerConfig(
            max_model_len=8192,
            is_encoder_decoder=False,
            async_scheduling=True,
        ),
        parallel_config=ParallelConfig(
            pipeline_parallel_size=2,
            distributed_executor_backend="mp",
            nnodes=2,
        ),
    )
    assert cfg.scheduler_config.async_scheduling is True


@dataclass
class _TestConfigFields:
    a: int
    b: dict = field(default_factory=dict)
    c: str = "default"
148
149


150
151
def test_get_field():
    """get_field returns the dataclass Field with its default/default_factory."""
    # Field declared with default_factory: it has no plain default.
    b = get_field(_TestConfigFields, "b")
    assert isinstance(b, Field)
    assert b.default is MISSING
    assert b.default_factory is dict

    # Field declared with a plain default: it has no default_factory.
    c = get_field(_TestConfigFields, "c")
    assert isinstance(c, Field)
    assert c.default == "default"
    assert c.default_factory is MISSING


@dataclass
class _TestNestedConfig:
    # Dataclass fixture whose single field is itself a dataclass. It is used
    # to exercise nested (dataclass and dict) updates in update_config.
    a: _TestConfigFields = field(default_factory=lambda: _TestConfigFields(a=0))


def test_update_config():
    """update_config applies flat and nested updates and rejects bad input."""
    # Simple update
    config1 = _TestConfigFields(a=0)
    new_config1 = update_config(config1, {"a": 42})
    assert new_config1.a == 42
    # Nonexistent field
    with pytest.raises(AssertionError):
        update_config(config1, {"nonexistent": 1})
    # Nested update with dataclass
    config2 = _TestNestedConfig()
    new_inner_config = _TestConfigFields(a=1, c="new_value")
    new_config2 = update_config(config2, {"a": new_inner_config})
    assert new_config2.a == new_inner_config
    # Nested update with dict
    config3 = _TestNestedConfig()
    new_config3 = update_config(config3, {"a": {"c": "new_value"}})
    assert new_config3.a.c == "new_value"
    # Nested update with invalid type
    with pytest.raises(AssertionError):
        update_config(config3, {"a": "new_value"})


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("distilbert/distilgpt2", "generate", "none"),
        ("intfloat/multilingual-e5-small", "pooling", "none"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none"),
        ("openai/whisper-small", "generate", "none"),
    ],
)
def test_auto_runner(model_id, expected_runner_type, expected_convert_type):
    """runner="auto" resolves the expected runner and convert types per model."""
    config = ModelConfig(model_id, runner="auto")

    assert config.runner_type == expected_runner_type
    assert config.convert_type == expected_convert_type


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("distilbert/distilgpt2", "pooling", "embed"),
        ("intfloat/multilingual-e5-small", "pooling", "none"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none"),
        ("openai/whisper-small", "pooling", "embed"),
    ],
)
def test_pooling_runner(model_id, expected_runner_type, expected_convert_type):
    """runner="pooling" forces pooling; generative models get convert "embed"."""
    config = ModelConfig(model_id, runner="pooling")

    assert config.runner_type == expected_runner_type
    assert config.convert_type == expected_convert_type


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_convert_type"),
    [
        ("Qwen/Qwen2.5-1.5B-Instruct", "draft", "none"),
    ],
)
def test_draft_runner(model_id, expected_runner_type, expected_convert_type):
    """runner="draft" yields the draft runner type with no model conversion."""
    config = ModelConfig(model_id, runner="draft")

    assert config.runner_type == expected_runner_type
    assert config.convert_type == expected_convert_type


# (model_id, expected max_model_len) pairs for test_disable_sliding_window,
# which builds each ModelConfig with disable_sliding_window=True.
MODEL_IDS_EXPECTED = [
    ("Qwen/Qwen1.5-7B", 32768),
    ("mistralai/Mistral-7B-v0.1", 4096),
    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """With the sliding window disabled, max_model_len matches the table."""
    model_id, expected = model_id_expected
    model_config = ModelConfig(model_id, disable_sliding_window=True)
    assert model_config.max_model_len == expected


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_pooling_config():
    """A sentence-transformers model gets MEAN/ALL pooling with activation on."""
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(model_id)

    assert model_config.pooler_config is not None
    assert model_config.pooler_config.use_activation
    assert model_config.pooler_config.seq_pooling_type == "MEAN"
    assert model_config.pooler_config.tok_pooling_type == "ALL"


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_pooling_config_from_args():
    """An explicitly passed PoolerConfig ends up unchanged on the ModelConfig."""
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    pooler_config = PoolerConfig(seq_pooling_type="CLS", use_activation=False)
    model_config = ModelConfig(model_id, pooler_config=pooler_config)

    assert asdict(model_config.pooler_config) == asdict(pooler_config)


@pytest.mark.parametrize(
    ("model_id", "default_pooling_type", "pooling_type"),
    [
        ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"),  # LLM
        ("intfloat/e5-small", "CLS", "MEAN"),  # BertModel
    ],
)
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
    """Check the architecture's default seq pooling type and the resolved one.

    The resolved type may differ from the architecture default
    (e.g. e5-small: CLS -> MEAN).
    """
    model_config = ModelConfig(model_id)
    assert model_config._model_info.default_seq_pooling_type == default_pooling_type
    assert model_config.pooler_config.seq_pooling_type == pooling_type


@pytest.mark.parametrize(
    ("model_id", "default_pooling_type", "pooling_type"),
    [
        ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"),  # reward
        ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"),  # step reward
    ],
)
def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
    """Check the architecture's default token pooling type and the resolved one."""
    model_config = ModelConfig(model_id)
    assert model_config._model_info.default_tok_pooling_type == default_pooling_type
    assert model_config.pooler_config.tok_pooling_type == pooling_type


@pytest.mark.parametrize(
    ("model_id", "expected_is_moe_model"),
    [
        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", False),
        ("RedHatAI/Llama-3.2-1B-FP8", False),
        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", False),
        ("RedHatAI/gpt-oss-20b", True),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
        ("RedHatAI/Llama-4-Scout-17B-16E-Instruct", True),
        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", True),
    ],
)
def test_moe_model_detection(model_id, expected_is_moe_model):
    """ModelConfig.is_moe classifies each model as MoE or dense correctly."""
    model_config = ModelConfig(model_id)
    # Compare against the expected classification, not just "is a bool".
    assert model_config.is_moe == expected_is_moe_model


@pytest.mark.parametrize(
    ("model_id", "quantized"),
    [
        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", True),
        ("RedHatAI/Llama-3.2-1B-FP8", True),
        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", True),
        ("RedHatAI/gpt-oss-20b", True),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", False),
    ],
)
def test_is_quantized(model_id, quantized):
    """ModelConfig.is_quantized matches the expected value for each model."""
    model_config = ModelConfig(model_id)
    # Compare against the expected value, not just "is a bool".
    assert model_config.is_quantized == quantized


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
def test_get_bert_tokenization_sentence_transformer_config():
    """_get_encoder_config exposes max_seq_length/do_lower_case for BGE."""
    model_id = "BAAI/bge-base-en-v1.5"
    bge_model_config = ModelConfig(model_id)

    bert_bge_model_config = bge_model_config._get_encoder_config()

    assert bert_bge_model_config["max_seq_length"] == 512
    assert bert_bge_model_config["do_lower_case"]


def test_rope_customization():
    """Default and overridden rope_parameters and the max_model_len they yield."""
    TEST_ROPE_PARAMETERS = {
        "rope_theta": 16_000_000.0,
        "rope_type": "dynamic",
        "factor": 2.0,
    }
    LLAMA_ROPE_PARAMETERS = {"rope_theta": 500000.0, "rope_type": "default"}
    LONGCHAT_ROPE_PARAMETERS = {"rope_type": "linear", "factor": 8.0}

    # Llama-3 without overrides keeps its default rope parameters.
    llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
    assert (
        getattr(llama_model_config.hf_config, "rope_parameters", None)
        == LLAMA_ROPE_PARAMETERS
    )
    assert llama_model_config.max_model_len == 8192

    # Overriding with dynamic scaling (factor 2.0) doubles the 8192 length.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
    )
    assert (
        getattr(llama_model_config.hf_config, "rope_parameters", None)
        == TEST_ROPE_PARAMETERS
    )
    assert llama_model_config.max_model_len == 16384

    longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
    # Check if LONGCHAT_ROPE_PARAMETERS entries are in longchat_model_config
    assert all(
        longchat_model_config.hf_config.rope_parameters.get(key) == value
        for key, value in LONGCHAT_ROPE_PARAMETERS.items()
    )
    assert longchat_model_config.max_model_len == 16384

    # Replacing longchat's linear scaling with TEST_ROPE_PARAMETERS.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        hf_overrides={
            "rope_parameters": TEST_ROPE_PARAMETERS,
        },
    )
    assert (
        getattr(longchat_model_config.hf_config, "rope_parameters", None)
        == TEST_ROPE_PARAMETERS
    )
    assert longchat_model_config.max_model_len == 4096


def test_nested_hf_overrides():
    """Test that nested hf_overrides work correctly."""
    # Override a single field inside text_config.
    model_config = ModelConfig(
        "Qwen/Qwen2-VL-2B-Instruct",
        hf_overrides={
            "text_config": {
                "hidden_size": 1024,
            },
        },
    )
    assert model_config.hf_config.text_config.hidden_size == 1024

    # Override several fields across multiple nested sub-configs at once.
    model_config = ModelConfig(
        "Qwen/Qwen2-VL-2B-Instruct",
        hf_overrides={
            "text_config": {
                "hidden_size": 2048,
                "num_attention_heads": 16,
            },
            "vision_config": {
                "hidden_size": 512,
            },
        },
    )
    assert model_config.hf_config.text_config.hidden_size == 2048
    assert model_config.hf_config.text_config.num_attention_heads == 16
    assert model_config.hf_config.vision_config.hidden_size == 512


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm."
)
@pytest.mark.parametrize(
    ("model_id", "is_encoder_decoder"),
    [
        ("facebook/opt-125m", False),
        ("openai/whisper-tiny", True),
        ("meta-llama/Llama-3.2-1B-Instruct", False),
    ],
)
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """ModelConfig.is_encoder_decoder matches the expected flag per model."""
    config = ModelConfig(model_id)

    assert config.is_encoder_decoder == is_encoder_decoder


@pytest.mark.parametrize(
    ("model_id", "uses_mrope"),
    [
        ("facebook/opt-125m", False),
        ("Qwen/Qwen2-VL-2B-Instruct", True),
    ],
)
def test_uses_mrope(model_id, uses_mrope):
    """ModelConfig.uses_mrope matches the expected flag per model."""
    config = ModelConfig(model_id)

    assert config.uses_mrope == uses_mrope


def test_generation_config_loading():
    """Sampling defaults follow generation_config and user overrides."""
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"

    # When generation_config is set to "vllm", the model's default generation
    # config is not loaded.
    model_config = ModelConfig(model_id, generation_config="vllm")
    assert model_config.get_diff_sampling_param() == {}

    # When generation_config is set to "auto", the model's default generation
    # config should be loaded.
    model_config = ModelConfig(model_id, generation_config="auto")

    correct_generation_config = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }

    assert model_config.get_diff_sampling_param() == correct_generation_config

    # The generation config can be overridden by the user.
    override_generation_config = {"temperature": 0.5, "top_k": 5}

    model_config = ModelConfig(
        model_id,
        generation_config="auto",
        override_generation_config=override_generation_config,
    )

    # User overrides take precedence over the model's defaults.
    override_result = correct_generation_config | override_generation_config

    assert model_config.get_diff_sampling_param() == override_result

    # When generation_config is set to "vllm" and override_generation_config
    # is set, the override_generation_config should be used directly.
    model_config = ModelConfig(
        model_id,
        generation_config="vllm",
        override_generation_config=override_generation_config,
    )

    assert model_config.get_diff_sampling_param() == override_generation_config


@pytest.mark.parametrize(
    "pt_load_map_location",
    [
        DEVICE_TYPE,
        {"": DEVICE_TYPE},
    ],
)
def test_load_config_pt_load_map_location(pt_load_map_location):
    """Both the str and dict forms of pt_load_map_location reach VllmConfig."""
    load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
    config = VllmConfig(load_config=load_config)

    assert config.load_config.pt_load_map_location == pt_load_map_location


@pytest.mark.parametrize(
    ("model_id", "max_model_len", "expected_max_len", "should_raise"),
    [
        ("BAAI/bge-reranker-base", None, 512, False),
        ("BAAI/bge-reranker-base", 256, 256, False),
        ("BAAI/bge-reranker-base", 513, 512, True),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", None, 131072, False),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 131073, 131072, True),
    ],
)
def test_get_and_verify_max_len(
    model_id, max_model_len, expected_max_len, should_raise
):
    """Test get_and_verify_max_len with different configurations.

    None resolves to the model's derived limit, smaller values are kept as-is,
    and values above the derived limit raise ValueError.
    """
    model_config = ModelConfig(model_id)

    if should_raise:
        with pytest.raises(ValueError):
            model_config.get_and_verify_max_len(max_model_len)
    else:
        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
        assert actual_max_len == expected_max_len


class MockConfig:
    """Simple mock object for testing maybe_pull_model_tokenizer_for_runai."""

    def __init__(self, model: str, tokenizer: str):
        # The S3 tests below expect model/tokenizer to be rewritten to local
        # directories by maybe_pull_model_tokenizer_for_runai.
        self.model = model
        self.tokenizer = tokenizer
        # Starts unset; presumably read/written by the method too — TODO confirm.
        self.model_weights = None


@pytest.mark.parametrize(
    "s3_url",
    [
        "s3://example-bucket-1/model/",
        "s3://example-bucket-2/model/",
    ],
)
@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files")
def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url):
    """Test that S3 URLs create deterministic local directories for model and
    tokenizer."""
    # Mock pull_files to avoid actually downloading files during tests.
    # (@patch injects the mock as the first positional argument.)
    mock_pull_files.return_value = None

    # Create first mock and run the method
    config1 = MockConfig(model=s3_url, tokenizer=s3_url)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url, s3_url)

    # Check that model and tokenizer point to existing directories
    assert os.path.exists(config1.model), (
        f"Model directory does not exist: {config1.model}"
    )
    assert os.path.isdir(config1.model), (
        f"Model path is not a directory: {config1.model}"
    )
    assert os.path.exists(config1.tokenizer), (
        f"Tokenizer directory does not exist: {config1.tokenizer}"
    )
    assert os.path.isdir(config1.tokenizer), (
        f"Tokenizer path is not a directory: {config1.tokenizer}"
    )

    # Verify that the paths are different from the original S3 URL
    assert config1.model != s3_url, "Model path should be converted to local directory"
    assert config1.tokenizer != s3_url, (
        "Tokenizer path should be converted to local directory"
    )

    # Store the original paths
    created_model_dir = config1.model
    created_tokenizer_dir = config1.tokenizer

    # Create a new mock and run the method with the same S3 URL
    config2 = MockConfig(model=s3_url, tokenizer=s3_url)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url, s3_url)

    # Check that the new directories exist
    assert os.path.exists(config2.model), (
        f"Model directory does not exist: {config2.model}"
    )
    assert os.path.isdir(config2.model), (
        f"Model path is not a directory: {config2.model}"
    )
    assert os.path.exists(config2.tokenizer), (
        f"Tokenizer directory does not exist: {config2.tokenizer}"
    )
    assert os.path.isdir(config2.tokenizer), (
        f"Tokenizer path is not a directory: {config2.tokenizer}"
    )

    # Verify that the paths are deterministic (same as before)
    assert config2.model == created_model_dir, (
        f"Model paths are not deterministic. "
        f"Original: {created_model_dir}, New: {config2.model}"
    )
    assert config2.tokenizer == created_tokenizer_dir, (
        f"Tokenizer paths are not deterministic. "
        f"Original: {created_tokenizer_dir}, New: {config2.tokenizer}"
    )


@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files")
def test_s3_url_different_models_create_different_directories(mock_pull_files):
    """Test that different S3 URLs create different local directories."""
    # Mock pull_files to avoid actually downloading files during tests
    mock_pull_files.return_value = None

    s3_url1 = "s3://example-bucket-1/model/"
    s3_url2 = "s3://example-bucket-2/model/"

    # Create mocks with different S3 URLs and run the method
    config1 = MockConfig(model=s3_url1, tokenizer=s3_url1)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url1, s3_url1)

    config2 = MockConfig(model=s3_url2, tokenizer=s3_url2)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url2, s3_url2)

    # Verify that different URLs produce different directories
    assert config1.model != config2.model, (
        f"Different S3 URLs should create different model directories. "
        f"URL1 model: {config1.model}, URL2 model: {config2.model}"
    )
    assert config1.tokenizer != config2.tokenizer, (
        f"Different S3 URLs should create different tokenizer directories. "
        f"URL1 tokenizer: {config1.tokenizer}, "
        f"URL2 tokenizer: {config2.tokenizer}"
    )

    # Verify that both sets of directories exist (isdir is False for a
    # missing path, so a separate exists() check is redundant).
    assert os.path.isdir(config1.model)
    assert os.path.isdir(config1.tokenizer)
    assert os.path.isdir(config2.model)
    assert os.path.isdir(config2.tokenizer)


@pytest.mark.parametrize(
    ("model_id", "expected_attn_type", "expected_result", "reason"),
    [
        # pooling models
        (
            "jason9693/Qwen2.5-1.5B-apeach",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen3-Embedding-0.6B",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen2.5-Math-PRM-7B",
            "decoder",
            False,
            "Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.",  # noqa: E501
        ),
        (
            "internlm/internlm2-1_8b-reward",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "BAAI/bge-base-en",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "boltuix/NeuroBERT-NER",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "papluca/xlm-roberta-base-language-detection",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        (
            "intfloat/e5-small",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        # multimodal models
        (
            "openai/clip-vit-base-patch32",
            "decoder",
            True,
            "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.",  # noqa: E501
        ),
        (
            "google/siglip-base-patch16-224",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn do not support chunked prefill.",  # noqa: E501
        ),
        # generate models
        (
            "Qwen/Qwen3-0.6B",
            "decoder",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "hybrid",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "ibm-granite/granite-4.0-h-small",
            "hybrid",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        (
            "state-spaces/mamba-130m-hf",
            "attention_free",
            True,
            "Generative models support chunked prefill.",  # noqa: E501
        ),
        # encoder_decoder models
        (
            "openai/whisper-small",
            "encoder_decoder",
            False,
            "Encoder decoder models do not support chunked prefill.",  # noqa: E501
        ),
    ],
)
def test_is_chunked_prefill_supported(
    model_id: str,
    expected_attn_type: str,
    expected_result: bool,
    reason: str,
    caplog_vllm,
):
    """Check chunked-prefill support per model and the debug reason it logs."""
    model_config = ModelConfig(model_id, trust_remote_code=True)
    assert model_config.attn_type == expected_attn_type
    # The reason is logged at DEBUG level while the property is evaluated.
    with caplog_vllm.at_level(level=logging.DEBUG, logger="vllm"):
        assert model_config.is_chunked_prefill_supported == expected_result
    assert reason in caplog_vllm.text


@pytest.mark.parametrize(
    ("model_id", "expected_attn_type", "expected_result", "reason"),
    [
        # pooling models
        (
            "jason9693/Qwen2.5-1.5B-apeach",
            "decoder",
            True,
785
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
786
787
788
789
790
        ),
        (
            "Qwen/Qwen3-Embedding-0.6B",
            "decoder",
            True,
791
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
792
793
794
795
796
        ),
        (
            "Qwen/Qwen2.5-Math-PRM-7B",
            "decoder",
            False,
797
            "Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.",  # noqa: E501
798
799
800
801
        ),
        (
            "internlm/internlm2-1_8b-reward",
            "decoder",
802
            True,
803
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
804
805
806
807
808
        ),
        (
            "BAAI/bge-base-en",
            "encoder_only",
            False,
809
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
810
811
812
813
814
        ),
        (
            "boltuix/NeuroBERT-NER",
            "encoder_only",
            False,
815
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
816
817
818
819
820
        ),
        (
            "papluca/xlm-roberta-base-language-detection",
            "encoder_only",
            False,
821
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
822
823
824
825
826
        ),
        (
            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
            "encoder_only",
            False,
827
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
828
829
830
831
832
        ),
        (
            "intfloat/e5-small",
            "encoder_only",
            False,
833
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
834
835
836
837
838
839
        ),
        # multimodal models
        (
            "openai/clip-vit-base-patch32",
            "decoder",
            True,
840
            "Pooling models with causal attn and LAST/ALL pooling support prefix caching.",  # noqa: E501
841
842
843
844
845
        ),
        (
            "google/siglip-base-patch16-224",
            "encoder_only",
            False,
846
            "Pooling models with bidirectional attn do not support prefix caching.",  # noqa: E501
847
848
849
850
851
852
        ),
        # generate models
        (
            "Qwen/Qwen3-0.6B",
            "decoder",
            True,
853
            "Generative models support prefix caching.",  # noqa: E501
854
855
856
857
858
        ),
        (
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "hybrid",
            False,
859
            "Hybrid models do not support prefix caching since the feature is still experimental.",  # noqa: E501
860
861
862
863
864
        ),
        (
            "ibm-granite/granite-4.0-h-small",
            "hybrid",
            False,
865
            "Hybrid models do not support prefix caching since the feature is still experimental.",  # noqa: E501
866
867
868
869
870
        ),
        (
            "state-spaces/mamba-130m-hf",
            "attention_free",
            False,
871
            "Attention free models do not support prefix caching since the feature is still experimental.",  # noqa: E501
872
873
874
875
876
877
        ),
        # encoder_decoder models
        (
            "openai/whisper-small",
            "encoder_decoder",
            False,
878
            "Encoder decoder models do not support prefix caching.",  # noqa: E501
879
880
881
882
883
884
885
886
887
888
889
890
        ),
    ],
)
def test_is_prefix_caching_supported(
    model_id: str,
    expected_attn_type: str,
    expected_result: bool,
    reason: str,
    caplog_vllm,
):
    """Check prefix-caching support, and its logged justification, per model.

    Builds a ``ModelConfig`` for ``model_id`` and confirms the detected
    attention type. It then reads ``is_prefix_caching_supported`` while
    capturing DEBUG output from the ``vllm`` logger, so the explanatory
    ``reason`` message can be asserted afterwards.
    """
    model_config = ModelConfig(model_id, trust_remote_code=True)
    assert model_config.attn_type == expected_attn_type

    # Reading the property is expected to log why prefix caching is (or is
    # not) supported; capture at DEBUG so that message is recorded.
    with caplog_vllm.at_level(level=logging.DEBUG, logger="vllm"):
        assert model_config.is_prefix_caching_supported == expected_result
    assert reason in caplog_vllm.text


@pytest.mark.parametrize(
    ("backend", "custom_ops", "expected"),
    [
        ("eager", [], True),
        ("eager", ["+fused_layernorm"], True),
        ("eager", ["all", "-fused_layernorm"], False),
        ("inductor", [], False),
        ("inductor", ["none", "+fused_layernorm"], True),
        ("inductor", ["none", "-fused_layernorm"], False),
    ],
)
def test_is_custom_op_enabled(backend: str, custom_ops: list[str], expected: bool):
    """Test that is_custom_op_enabled works correctly.

    Combines a compile backend with a ``custom_ops`` list and checks whether
    ``fused_layernorm`` ends up reported as enabled.
    """
    compilation_config = CompilationConfig(backend=backend, custom_ops=custom_ops)
    vllm_config = VllmConfig(compilation_config=compilation_config)

    enabled = vllm_config.compilation_config.is_custom_op_enabled("fused_layernorm")
    assert enabled is expected


def test_vllm_config_defaults_are_none():
    """Every field an optimization level may fill in starts out as None.

    A bare ``VllmConfig`` shell is created with ``object.__new__`` (so its
    ``__init__`` never runs) for each optimization level. Every key in that
    level's ``compilation_config`` defaults, including the nested
    ``pass_config`` ones, must still be unset (``None``) on a fresh
    ``CompilationConfig``.
    """
    for level in OptimizationLevel:
        shell = object.__new__(VllmConfig)
        shell.compilation_config = CompilationConfig()
        shell.optimization_level = level
        shell.model_config = None

        level_defaults = OPTIMIZATION_LEVEL_TO_CONFIG[level]
        compilation_defaults = level_defaults["compilation_config"]
        compilation = shell.compilation_config

        # Nested pass_config entries.
        for pass_name in compilation_defaults["pass_config"]:
            assert getattr(compilation.pass_config, pass_name) is None

        # Top-level compilation_config entries (pass_config handled above).
        for field_name in compilation_defaults:
            if field_name == "pass_config":
                continue
            assert getattr(compilation, field_name) is None


@pytest.mark.parametrize(
    ("model_id", "compilation_config", "optimization_level"),
    [
        (
            None,
            CompilationConfig(backend="eager", custom_ops=["+quant_fp8"]),
            OptimizationLevel.O0,
        ),
        (None, CompilationConfig(), OptimizationLevel.O0),
        (None, CompilationConfig(), OptimizationLevel.O1),
        (None, CompilationConfig(), OptimizationLevel.O2),
        (None, CompilationConfig(), OptimizationLevel.O3),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(backend="inductor", custom_ops=["+quant_fp8"]),
            OptimizationLevel.O2,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O0,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O1,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O2,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O3,
        ),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O0),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O1),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O2),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O3),
    ],
)
def test_vllm_config_defaults(model_id, compilation_config, optimization_level):
    """Test that optimization-level defaults are correctly applied.

    Builds a ``VllmConfig`` (with a ``ModelConfig`` only when ``model_id`` is
    given) and checks every entry of the level's ``compilation_config``
    defaults against the resulting config. Callable defaults are evaluated
    against the constructed ``VllmConfig`` to get the expected value.
    """
    # Only pass model_config when a model is requested; otherwise let
    # VllmConfig fall back to its own default, as the no-model case intends.
    config_kwargs = {}
    if model_id is not None:
        config_kwargs["model_config"] = ModelConfig(model_id)
    vllm_config = VllmConfig(
        **config_kwargs,
        compilation_config=compilation_config,
        optimization_level=optimization_level,
    )

    # Use the global optimization level defaults
    default_config = OPTIMIZATION_LEVEL_TO_CONFIG[optimization_level]

    # Verify pass_config defaults (nested under compilation_config)
    pass_config_dict = default_config["compilation_config"]["pass_config"]
    for pass_k, pass_v in pass_config_dict.items():
        actual = getattr(vllm_config.compilation_config.pass_config, pass_k)
        expected = pass_v(vllm_config) if callable(pass_v) else pass_v
        assert actual == expected, (
            f"pass_config.{pass_k}: expected {expected}, got {actual}"
        )

    # Verify other compilation_config defaults
    compilation_config_dict = default_config["compilation_config"]
    for k, v in compilation_config_dict.items():
        if k == "pass_config":
            continue
        actual = getattr(vllm_config.compilation_config, k)
        expected = v(vllm_config) if callable(v) else v
        # On platforms without static graph support, __post_init__ forces
        # cudagraph_mode to NONE; expect that instead of the level default.
        if k == "cudagraph_mode" and not current_platform.support_static_graph_mode():
            expected = CUDAGraphMode.NONE
        assert actual == expected, (
            f"compilation_config.{k}: expected {expected}, got {actual}"
        )


def test_vllm_config_callable_defaults():
    """Test that callable defaults work in the config system.

    Verifies that callables used as default configs can inspect VllmConfig
    properties (e.g., ``model_config.is_quantized``, ``model_config.is_moe``)
    to conditionally set optimization flags.
    """
    config_no_model = VllmConfig(optimization_level=OptimizationLevel.O2)

    # Callable that checks if model exists
    def has_model(cfg):
        return cfg.model_config is not None

    assert has_model(config_no_model) is False

    # Test with quantized model
    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
    config_quantized = VllmConfig(
        model_config=quantized_model, optimization_level=OptimizationLevel.O2
    )

    def enable_if_quantized(cfg):
        return cfg.model_config is not None and cfg.model_config.is_quantized

    assert enable_if_quantized(config_quantized) is True
    assert enable_if_quantized(config_no_model) is False

    # Test with MoE model
    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
    config_moe = VllmConfig(
        model_config=moe_model, optimization_level=OptimizationLevel.O2
    )

    # Enabled only for a present, non-MoE model.
    def enable_if_sequential(cfg):
        return cfg.model_config is not None and not cfg.model_config.is_moe

    assert enable_if_sequential(config_moe) is False
    assert enable_if_sequential(config_quantized) is True


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Explicit overrides may be force-overwritten without static graph support.",
)
def test_vllm_config_explicit_overrides():
    """Test that explicit property overrides work correctly with callable defaults.

    When users explicitly set configuration properties, those values
    take precedence over callable defaults, across different models and
    optimization levels.
    """
    from vllm.config.compilation import PassConfig

    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
    regular_model = ModelConfig("Qwen/Qwen1.5-7B")

    # Explicit compilation mode override on O0 (where default is NONE)
    compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    config = VllmConfig(
        optimization_level=OptimizationLevel.O0,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
    # Overriding only the mode leaves the O0 cudagraph default untouched.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE

    # Explicit pass config flags to override defaults
    pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True)
    compilation_config = CompilationConfig(pass_config=pass_config)
    config = VllmConfig(
        optimization_level=OptimizationLevel.O0,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.pass_config.eliminate_noops is True
    assert config.compilation_config.pass_config.fuse_attn_quant is True

    # Explicit cudagraph mode override on quantized model at O2
    pass_config = PassConfig(enable_qk_norm_rope_fusion=True)
    compilation_config = CompilationConfig(
        cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
    )
    config = VllmConfig(
        model_config=quantized_model,
        optimization_level=OptimizationLevel.O2,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert config.compilation_config.pass_config.enable_qk_norm_rope_fusion is True
    # Mode should still use default for O2
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE

    # Different optimization levels with same model
    config_o0 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O0
    )
    config_o2 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O2
    )
    assert config_o0.compilation_config.mode == CompilationMode.NONE
    assert config_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_o0.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert (
        config_o2.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    )

    # Same optimization level across different model types
    config_moe_o2 = VllmConfig(
        model_config=moe_model, optimization_level=OptimizationLevel.O2
    )
    config_regular_o2 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O2
    )
    config_quantized_o2 = VllmConfig(
        model_config=quantized_model, optimization_level=OptimizationLevel.O2
    )
    # All should have same base compilation settings at O2
    assert config_moe_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_regular_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_quantized_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert (
        config_moe_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )
    assert (
        config_regular_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )

    # Override one field but not others
    pass_config = PassConfig(eliminate_noops=False)
    compilation_config = CompilationConfig(pass_config=pass_config)
    config = VllmConfig(
        model_config=regular_model,
        optimization_level=OptimizationLevel.O2,
        compilation_config=compilation_config,
    )
    # Explicit override should be respected
    assert config.compilation_config.pass_config.eliminate_noops is False
    # Other fields should still use defaults
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE


def test_fusion_pass_op_priority():
    """Custom-op enablement and IR op priority should decide whether the
    rms_norm+quant fusion (``fuse_norm_quant``) is on by default."""

    # Plain default config (O2): the fusion stays off.
    default_cfg = VllmConfig()
    assert not default_cfg.compilation_config.pass_config.fuse_norm_quant

    # O1 with the rms_norm custom op explicitly enabled: the fusion turns on.
    custom_op_cfg = VllmConfig(
        optimization_level=OptimizationLevel.O1,
        compilation_config=CompilationConfig(custom_ops=["+rms_norm"]),
    )
    assert custom_op_cfg.compilation_config.pass_config.fuse_norm_quant

    # Selecting a custom RMSNorm kernel through vLLM IR priority.
    # vLLM IR currently only covers the non-residual rms_norm variant; this
    # is expected to be resolved soon.
    ir_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])
    ir_cfg = VllmConfig(kernel_config=KernelConfig(ir_op_priority=ir_priority))
    assert ir_cfg.compilation_config.pass_config.fuse_norm_quant

    # A block-fp8 checkpoint should switch on quant_fp8 by itself.
    fp8_cfg = VllmConfig(model_config=ModelConfig("Qwen/Qwen3-4B-FP8"))
    assert "+quant_fp8" in fp8_cfg.compilation_config.custom_ops
    assert fp8_cfg.compilation_config.pass_config.fuse_norm_quant


def test_scheduler_config_init():
    """SchedulerConfig needs its InitVars, and never stores them as attributes."""
    # The positional InitVars are mandatory: an InitVar cannot carry a default,
    # otherwise it would turn into a regular attribute.
    with pytest.raises(ValidationError):
        SchedulerConfig()

    # An InitVar is consumed at construction time and is not kept on the
    # instance, so reading it back must fail.
    with pytest.raises(AttributeError):
        _ = SchedulerConfig.default_factory().max_model_len


@pytest.mark.parametrize(
    (
        "model_id",
        "data_parallel_size",
        "external_lb",
        "expected_needs_coordinator",
    ),
    [
        # Non-MoE model with DP=1 should not need coordinator
        ("facebook/opt-125m", 1, False, False),
        # Non-MoE model with DP>1 internal LB should need coordinator
        ("facebook/opt-125m", 2, False, True),
        # Non-MoE model with DP>1 external LB should not need coordinator
        ("facebook/opt-125m", 2, True, False),
        # MoE model with DP=1 should not need coordinator
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 1, False, False),
        # MoE model with DP>1 internal LB should need both coordinator
        # and wave coordination
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, False, True),
        # MoE model with DP>1 external LB needs coordinator for wave coordination
        # (wave coordination runs in coordinator process)
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, True, True),
    ],
)
def test_needs_dp_coordination(
    model_id,
    data_parallel_size,
    external_lb,
    expected_needs_coordinator,
):
    """Test that DP coordinator and wave coordination are configured correctly.

    ``needs_dp_coordinator`` is checked for MoE and non-MoE models across
    data-parallel sizes and internal/external load balancing.
    """
    # ParallelConfig comes from the module-level ``vllm.config`` import.
    model_config = ModelConfig(model_id)
    parallel_config = ParallelConfig(
        data_parallel_size=data_parallel_size,
        data_parallel_external_lb=external_lb,
    )
    vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)

    assert vllm_config.needs_dp_coordinator == expected_needs_coordinator


def test_renderer_num_workers_with_mm_cache():
    """Disallow renderer_num_workers > 1 when mm processor cache is enabled,
    since neither cache type is thread-safe."""
    mm_model = "Qwen/Qwen2-VL-2B-Instruct"

    # Multi-worker rendering combined with an active cache must be rejected:
    # first with the default cache (cache_gb=4), then with an explicit size.
    rejected_kwargs = [
        {"renderer_num_workers": 4},
        {"renderer_num_workers": 2, "mm_processor_cache_gb": 1.0},
    ]
    for kwargs in rejected_kwargs:
        with pytest.raises(ValueError, match="renderer-num-workers"):
            ModelConfig(mm_model, **kwargs)

    # Accepted: multi-worker with the cache disabled, and a single worker with
    # the default cache.
    accepted_kwargs = [
        ({"renderer_num_workers": 4, "mm_processor_cache_gb": 0}, 4),
        ({"renderer_num_workers": 1}, 1),
    ]
    for kwargs, expected_workers in accepted_kwargs:
        config = ModelConfig(mm_model, **kwargs)
        assert config.renderer_num_workers == expected_workers


def test_eagle_draft_model_config():
    """The EAGLE draft model config should resolve to EagleLlamaForCausalLM."""
    target_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
    )
    speculative_config = SpeculativeConfig(
        model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
        num_speculative_tokens=1,
        target_model_config=target_model_config,
        target_parallel_config=ParallelConfig(),
    )
    draft = speculative_config.draft_model_config
    eagle_arch = "EagleLlamaForCausalLM"

    # Both the full HF config and its text sub-config must be rewritten.
    for hf_cfg in (draft.hf_config, draft.hf_text_config):
        assert hf_cfg.architectures == [eagle_arch]
        assert hf_cfg.model_type == "eagle"

    # The ModelConfig-level views must agree as well.
    assert draft.architectures == [eagle_arch]
    assert draft.architecture == eagle_arch


def test_ir_op_priority_default():
    """Test that IR op priority defaults are set correctly.

    ``IrOpPriorityConfig.with_default`` applies a shared priority list to ops,
    while per-op keyword arguments take precedence over that default.
    """
    # IrOpPriorityConfig comes from the module-level vllm.config.kernel import.

    # Assert default is applied to ops
    priority_config = IrOpPriorityConfig.with_default(["vllm_c", "native"])
    assert priority_config.rms_norm == ["vllm_c", "native"]

    # Assert single ops override the default
    assert IrOpPriorityConfig.with_default(
        ["vllm_c", "native"], rms_norm=["oink", "native"]
    ) == IrOpPriorityConfig(rms_norm=["oink", "native"])


def test_ir_op_priority_str():
    """Test that passing a comma-delimited string works"""
    from vllm.config.kernel import IrOpPriorityConfig

    priority_config = IrOpPriorityConfig(rms_norm="vllm_c")
    assert priority_config.rms_norm == ["vllm_c"]

    priority_config = IrOpPriorityConfig(rms_norm="vllm_c,native")
    assert priority_config.rms_norm == ["vllm_c", "native"]

    priority_config = IrOpPriorityConfig(rms_norm=" native, vllm_c ")
    assert priority_config.rms_norm == ["native", "vllm_c"]

    with pytest.raises(pydantic.ValidationError):
        # must be list of only strings
        priority_config = IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"])