# test_config.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import MISSING, Field, asdict, dataclass, field

import pytest

from vllm.compilation.backends import VllmBackend
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
                         get_field)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
13

14

15
16
17
18
19
20
21
22
23
24
25
26
def test_compile_config_repr_succeeds():
    """repr() of a VllmConfig must still succeed after VllmBackend setup.

    Building a VllmBackend and calling configure_post_pass() mutates the
    config object. Afterwards the repr must still name the class and
    include the ``inductor_passes`` entry.
    """
    vllm_config = VllmConfig()
    # Keep the backend bound to a name so it stays alive while the config
    # is inspected below.
    compile_backend = VllmBackend(vllm_config)
    compile_backend.configure_post_pass()

    rendered = repr(vllm_config)
    for fragment in ('VllmConfig', 'inductor_passes'):
        assert fragment in rendered


27
28
29
30
31
@dataclass
class _TestConfigFields:
    """Dataclass fixture for test_get_field, one field per default style."""
    # Neither a default nor a default_factory.
    a: int
    # Mutable default supplied through default_factory.
    b: dict = field(default_factory=dict)
    # Plain default value.
    c: str = "default"
32
33


34
def test_get_field():
    """get_field returns a ``dataclasses.Field`` carrying the field's default.

    A field with neither a default nor a default factory is rejected with
    ``ValueError``.
    """
    # "a" has no default and no default_factory.
    with pytest.raises(ValueError):
        get_field(_TestConfigFields, "a")

    # "b" is declared with default_factory=dict.
    b = get_field(_TestConfigFields, "b")
    assert isinstance(b, Field)
    assert b.default is MISSING
    assert b.default_factory is dict

    # "c" is declared with a plain default value.
    c = get_field(_TestConfigFields, "c")
    assert isinstance(c, Field)
    assert c.default == "default"
    assert c.default_factory is MISSING


49
50
51
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [
        ("distilbert/distilgpt2", "generate", "generate"),
        ("intfloat/multilingual-e5-small", "pooling", "embed"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
        ("openai/whisper-small", "transcription", "transcription"),
    ],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
    """With ``task="auto"``, ModelConfig infers the runner type and task.

    Each model must resolve to the runner type and task listed in the
    parametrization above.
    """
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [
        ("distilbert/distilgpt2", "pooling", "embed"),
        ("intfloat/multilingual-e5-small", "pooling", "embed"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"),
        ("openai/whisper-small", "pooling", "embed"),
    ],
)
def test_score_task(model_id, expected_runner_type, expected_task):
    """With ``task="score"``, every model uses the pooling runner.

    The resolved task (``embed`` or ``classify``) depends on the model, as
    listed in the parametrization above.
    """
    config = ModelConfig(
        model_id,
        task="score",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task


@pytest.mark.parametrize(("model_id", "bad_task"), [
    ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
])
def test_incorrect_task(model_id, bad_task):
    """Requesting a task the model does not support raises ``ValueError``."""
    with pytest.raises(ValueError, match=r"does not support the .* task"):
        ModelConfig(
            model_id,
            task=bad_task,
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
        )


117
118
119
120
121
122
123
124
125
126
127
128
# (model_id, expected max_model_len) pairs for test_disable_sliding_window.
MODEL_IDS_EXPECTED = [
    ("Qwen/Qwen1.5-7B", 32768),
    ("mistralai/Mistral-7B-v0.1", 4096),
    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """max_model_len with ``disable_sliding_window=True`` is as expected."""
    model_id, expected = model_id_expected
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        disable_sliding_window=True,
    )
    assert model_config.max_model_len == expected

140
141
142
143
144
145
146
147

def test_get_sliding_window():
    """get_sliding_window() reflects the HF config's sliding-window fields."""
    TEST_SLIDING_WINDOW = 4096
    # Test that the sliding window is correctly computed.
    # For Qwen1.5/Qwen2, get_sliding_window() should be None
    # when use_sliding_window is False.
    qwen2_model_config = ModelConfig(
        "Qwen/Qwen1.5-7B",
        task="auto",
        tokenizer="Qwen/Qwen1.5-7B",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    qwen2_model_config.hf_config.use_sliding_window = False
    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert qwen2_model_config.get_sliding_window() is None

    qwen2_model_config.hf_config.use_sliding_window = True
    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

    # For Mistral, get_sliding_window() follows hf_config.sliding_window.
    mistral_model_config = ModelConfig(
        "mistralai/Mistral-7B-v0.1",
        task="auto",
        tokenizer="mistralai/Mistral-7B-v0.1",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )
    mistral_model_config.hf_config.sliding_window = None
    assert mistral_model_config.get_sliding_window() is None

    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


181
182
183
184
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
    """Without an override, the pooler config is mean pooling, normalized.

    Checked for sentence-transformers/all-MiniLM-L12-v2 via
    ``_init_pooler_config()``.
    """
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    pooling_config = model_config._init_pooler_config()
    assert pooling_config is not None

    assert pooling_config.normalize
    assert pooling_config.pooling_type == PoolingType.MEAN.name
201
202
203
204
205
206


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
    """An explicit override_pooler_config is what _init_pooler_config yields.

    Equality is checked field by field via ``dataclasses.asdict``.
    """
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(model_id,
                               task="auto",
                               tokenizer=model_id,
                               tokenizer_mode="auto",
                               trust_remote_code=False,
                               seed=0,
                               dtype="float16",
                               revision=None)

    override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
    # Set the override after construction, then re-derive the pooler config.
    model_config.override_pooler_config = override_pooler_config

    pooling_config = model_config._init_pooler_config()
    assert pooling_config is not None
    assert asdict(pooling_config) == asdict(override_pooler_config)
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
    """_get_encoder_config() exposes BGE's encoder tokenization settings.

    For BAAI/bge-base-en-v1.5 the expected values are a max_seq_length of
    512 with lower-casing enabled.
    """
    bge_name = "BAAI/bge-base-en-v1.5"
    bge_model_config = ModelConfig(
        model=bge_name,
        task="auto",
        tokenizer=bge_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    encoder_settings = bge_model_config._get_encoder_config()
    assert encoder_settings["max_seq_length"] == 512
    assert encoder_settings["do_lower_case"]


244
def test_rope_customization():
    """RoPE settings from the HF config and ``hf_overrides`` drive max length.

    Checks the default Llama-3 and LongChat RoPE parameters. Also checks
    that ``hf_overrides`` replaces ``rope_scaling`` / ``rope_theta`` on the
    HF config, and that ``max_model_len`` follows the effective scaling.
    """
    TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
    TEST_ROPE_THETA = 16_000_000.0
    LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}

    # Baseline Llama-3: no rope_scaling, rope_theta of 500k, 8k context.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        task="auto",
        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
    assert llama_model_config.max_model_len == 8192

    # hf_overrides replaces both RoPE fields; the factor-2.0 dynamic scaling
    # doubles max_model_len (8192 -> 16384).
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        task="auto",
        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        hf_overrides={
            "rope_scaling": TEST_ROPE_SCALING,
            "rope_theta": TEST_ROPE_THETA,
        },
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert getattr(llama_model_config.hf_config, "rope_theta",
                   None) == TEST_ROPE_THETA
    assert llama_model_config.max_model_len == 16384

    # LongChat ships linear RoPE scaling (factor 8.0) and a 16k context.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        task="auto",
        tokenizer="lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    # Subset check: only the keys listed in LONGCHAT_ROPE_SCALING are
    # compared. Comparing dicts (rather than all(...)) lets pytest show
    # exactly which entries differ on failure.
    longchat_rope_scaling = longchat_model_config.hf_config.rope_scaling
    assert {
        key: longchat_rope_scaling.get(key)
        for key in LONGCHAT_ROPE_SCALING
    } == LONGCHAT_ROPE_SCALING
    assert longchat_model_config.max_model_len == 16384

    # Overriding LongChat's rope_scaling changes the derived max_model_len.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        task="auto",
        tokenizer="lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        hf_overrides={
            "rope_scaling": TEST_ROPE_SCALING,
        },
    )
    assert getattr(longchat_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert longchat_model_config.max_model_len == 4096
311
312


313
314
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
    ("facebook/opt-125m", False),
    ("facebook/bart-base", True),
    ("meta-llama/Llama-3.2-1B-Instruct", False),
    ("meta-llama/Llama-3.2-11B-Vision", True),
])
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """ModelConfig.is_encoder_decoder matches the expected flag per model."""
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )

    assert config.is_encoder_decoder == is_encoder_decoder


@pytest.mark.parametrize(("model_id", "uses_mrope"), [
    ("facebook/opt-125m", False),
    ("Qwen/Qwen2-VL-2B-Instruct", True),
])
def test_uses_mrope(model_id, uses_mrope):
    """ModelConfig.uses_mrope reports the expected value for each model."""
    shared_kwargs = dict(
        task="auto",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    model_config = ModelConfig(model_id, tokenizer=model_id, **shared_kwargs)

    assert model_config.uses_mrope == uses_mrope
351
352
353
354
355


def test_generation_config_loading():
    """generation_config and override_generation_config set sampling defaults.

    get_diff_sampling_param() is checked for four combinations: "vllm"
    alone, "auto" alone, "auto" with overrides, and "vllm" with overrides.
    """
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"

    # When generation_config is set to "vllm", the model's default
    # generation config is not loaded, so no sampling parameters differ.
    model_config = ModelConfig(model_id,
                               task="auto",
                               tokenizer=model_id,
                               tokenizer_mode="auto",
                               trust_remote_code=False,
                               seed=0,
                               dtype="float16",
                               generation_config="vllm")

    assert model_config.get_diff_sampling_param() == {}

    # When generation_config is set to "auto", the model's default
    # generation config is loaded.
    model_config = ModelConfig(model_id,
                               task="auto",
                               tokenizer=model_id,
                               tokenizer_mode="auto",
                               trust_remote_code=False,
                               seed=0,
                               dtype="float16",
                               generation_config="auto")

    correct_generation_config = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }

    assert model_config.get_diff_sampling_param() == correct_generation_config

    # Values passed in override_generation_config take precedence over the
    # loaded defaults.
    override_generation_config = {"temperature": 0.5, "top_k": 5}

    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        generation_config="auto",
        override_generation_config=override_generation_config)

    override_result = correct_generation_config.copy()
    override_result.update(override_generation_config)

    assert model_config.get_diff_sampling_param() == override_result

    # When generation_config is "vllm" and override_generation_config is
    # set, the override_generation_config is used directly.
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        generation_config="vllm",
        override_generation_config=override_generation_config)

    assert model_config.get_diff_sampling_param() == override_generation_config
421
422
423
424
425
426
427
428
429
430
431
432
433


@pytest.mark.parametrize("pt_load_map_location", ["cuda", {"": "cuda"}])
def test_load_config_pt_load_map_location(pt_load_map_location):
    """VllmConfig keeps the pt_load_map_location passed through LoadConfig.

    Both the plain-string form and the dict form are checked.
    """
    cfg = VllmConfig(load_config=LoadConfig(
        pt_load_map_location=pt_load_map_location))

    assert cfg.load_config.pt_load_map_location == pt_load_map_location
434
435
436
437
438
439
440


@pytest.mark.parametrize(
    ("model_id", "max_model_len", "expected_max_len", "should_raise"), [
        ("BAAI/bge-reranker-base", None, 512, False),
        ("BAAI/bge-reranker-base", 256, 256, False),
        ("BAAI/bge-reranker-base", 513, 512, True),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", None, 131072, False),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 131073, 131072, True),
    ])
def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
                                should_raise):
    """Test get_and_verify_max_len with different configurations.

    ``None`` resolves to the model's maximum length, and a smaller value is
    returned unchanged. A value above the model's maximum raises
    ``ValueError``.
    """
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    if should_raise:
        with pytest.raises(ValueError):
            model_config.get_and_verify_max_len(max_model_len)
    else:
        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
        assert actual_max_len == expected_max_len