test_config.py 16 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import MISSING, Field, asdict, dataclass, field
5

6
import pytest
7
import os
8

9
from vllm.compilation.backends import VllmBackend
10
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
11
                         get_field)
12
13
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
14
from utils import models_path_prefix
15

16

17
18
19
20
21
22
23
24
25
26
27
28
def test_compile_config_repr_succeeds():
    """repr() of a VllmConfig must still work once a VllmBackend used it."""
    # VllmBackend mutates the config object it receives, so the backend is
    # created and configured before the repr is taken.
    vllm_config = VllmConfig()
    compiler_backend = VllmBackend(vllm_config)
    compiler_backend.configure_post_pass()

    rendered = repr(vllm_config)
    for expected_fragment in ('VllmConfig', 'inductor_passes'):
        assert expected_fragment in rendered


29
30
31
32
33
@dataclass
class _TestConfigFields:
    """Minimal dataclass fixture passed to get_field() in test_get_field.

    It carries one field of each kind the test distinguishes: no default,
    a default_factory, and a plain default value.
    """
    # Required field: neither a default nor a default_factory.
    a: int
    # Field whose default comes from a factory (dict).
    b: dict = field(default_factory=dict)
    # Field with a plain default value.
    c: str = "default"
34
35


36
def test_get_field():
    """Check get_field() on a required, a factory and a plain-default field."""
    # "a" declares neither a default nor a default_factory; get_field is
    # expected to reject it.
    with pytest.raises(ValueError):
        get_field(_TestConfigFields, "a")

    # "b" only has a default_factory.
    factory_field = get_field(_TestConfigFields, "b")
    assert isinstance(factory_field, Field)
    assert factory_field.default is MISSING
    assert factory_field.default_factory is dict

    # "c" only has a plain default value.
    plain_field = get_field(_TestConfigFields, "c")
    assert isinstance(plain_field, Field)
    assert plain_field.default == "default"
    assert plain_field.default_factory is MISSING


51
52
53
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [
        (os.path.join(models_path_prefix, repo), runner_type, task)
        for repo, runner_type, task in (
            ("distilbert/distilgpt2", "generate", "generate"),
            ("intfloat/multilingual-e5-small", "pooling", "embed"),
            ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
            ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
            ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
            ("openai/whisper-small", "transcription", "transcription"),
        )
    ],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
    """With task="auto", the resolved runner type and task must match."""
    model_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )
    model_config = ModelConfig(model_id, **model_kwargs)

    assert model_config.runner_type == expected_runner_type
    assert model_config.task == expected_task


77
78
79
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [
        (os.path.join(models_path_prefix, repo), runner_type, task)
        for repo, runner_type, task in (
            ("distilbert/distilgpt2", "pooling", "embed"),
            ("intfloat/multilingual-e5-small", "pooling", "embed"),
            ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
            ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
            ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"),
            ("openai/whisper-small", "pooling", "embed"),
        )
    ],
)
def test_score_task(model_id, expected_runner_type, expected_task):
    """With task="score", the resolved runner type and task must match."""
    model_kwargs = dict(
        task="score",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )
    model_config = ModelConfig(model_id, **model_kwargs)

    assert model_config.runner_type == expected_runner_type
    assert model_config.task == expected_task


@pytest.mark.parametrize(
    ("model_id", "bad_task"),
    [
        (os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"),
         "generate"),
    ],
)
def test_incorrect_task(model_id, bad_task):
    """Requesting an unsupported task must raise a ValueError."""
    model_kwargs = dict(
        task=bad_task,
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )
    # The error message is matched as well, not only the exception type.
    with pytest.raises(ValueError, match=r"does not support the .* task"):
        ModelConfig(model_id, **model_kwargs)


119
# (model path, max_model_len expected by test_disable_sliding_window) pairs.
MODEL_IDS_EXPECTED = [
    (os.path.join(models_path_prefix, repo), max_len)
    for repo, max_len in (
        ("Qwen/Qwen1.5-7B", 32768),
        ("mistralai/Mistral-7B-v0.1", 4096),
        ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
    )
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """Check max_model_len when the config disables the sliding window."""
    model_path, expected_max_len = model_id_expected
    model_kwargs = dict(
        task="auto",
        tokenizer=model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        disable_sliding_window=True,
    )
    config = ModelConfig(model_path, **model_kwargs)

    assert config.max_model_len == expected_max_len

142
143
144
145
146
147
148

def test_get_sliding_window():
    """Check get_sliding_window() against values written onto hf_config."""
    window = 4096

    def build_config(repo):
        # Model and tokenizer are loaded from the same local path.
        path = os.path.join(models_path_prefix, repo)
        return ModelConfig(
            path,
            task="auto",
            tokenizer=path,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
            revision=None,
        )

    # For Qwen1.5/Qwen2, get_sliding_window() should be None
    # when use_sliding_window is False.
    qwen_config = build_config("Qwen/Qwen1.5-7B")
    qwen_hf = qwen_config.hf_config

    qwen_hf.use_sliding_window = False
    qwen_hf.sliding_window = window
    assert qwen_config.get_sliding_window() is None

    qwen_hf.use_sliding_window = True
    assert qwen_config.get_sliding_window() == window

    # For Mistral only hf_config.sliding_window itself is toggled.
    mistral_config = build_config("mistralai/Mistral-7B-v0.1")
    mistral_hf = mistral_config.hf_config

    mistral_hf.sliding_window = None
    assert mistral_config.get_sliding_window() is None

    mistral_hf.sliding_window = window
    assert mistral_config.get_sliding_window() == window


183
184
185
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
    """Check the pooler config built for all-MiniLM-L12-v2."""
    model_path = os.path.join(models_path_prefix,
                              "sentence-transformers/all-MiniLM-L12-v2")
    model_kwargs = dict(
        task="auto",
        tokenizer=model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )
    config = ModelConfig(model_path, **model_kwargs)

    pooler = config._init_pooler_config()
    assert pooler is not None

    # Expected for this model: normalized, mean pooling.
    assert pooler.normalize
    assert pooler.pooling_type == PoolingType.MEAN.name
203
204
205
206
207


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
    """An override_pooler_config set on the model config must be returned."""
    model_path = os.path.join(models_path_prefix,
                              "sentence-transformers/all-MiniLM-L12-v2")
    config = ModelConfig(
        model_path,
        task="auto",
        tokenizer=model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    # The override is attached after construction, then the pooler config
    # is re-initialised explicitly.
    requested = PoolerConfig(pooling_type='CLS', normalize=True)
    config.override_pooler_config = requested

    resolved = config._init_pooler_config()
    assert resolved is not None
    assert asdict(resolved) == asdict(requested)
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
    """Check the encoder config read for BAAI/bge-base-en-v1.5.

    The model is resolved through models_path_prefix, like the other
    model-loading tests in this module, instead of a bare hub id.
    """
    model_id = os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5")
    bge_model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    bert_bge_model_config = bge_model_config._get_encoder_config()

    assert bert_bge_model_config["max_seq_length"] == 512
    assert bert_bge_model_config["do_lower_case"]


246
def test_rope_customization():
    """Check rope settings and max_model_len with and without hf_overrides."""
    rope_override = {"rope_type": "dynamic", "factor": 2.0}
    theta_override = 16_000_000.0
    longchat_native_rope = {"rope_type": "linear", "factor": 8.0}

    def build_config(repo, **extra_kwargs):
        # Model and tokenizer are loaded from the same local path; any
        # hf_overrides are forwarded untouched.
        path = os.path.join(models_path_prefix, repo)
        return ModelConfig(
            path,
            task="auto",
            tokenizer=path,
            tokenizer_mode="auto",
            trust_remote_code=False,
            dtype="float16",
            seed=0,
            **extra_kwargs,
        )

    llama_repo = "meta-llama/Meta-Llama-3-8B-Instruct"
    longchat_repo = "lmsys/longchat-13b-16k"

    # Llama 3 without overrides.
    llama_config = build_config(llama_repo)
    llama_hf = llama_config.hf_config
    assert getattr(llama_hf, "rope_scaling", None) is None
    assert getattr(llama_hf, "rope_theta", None) == 500_000
    assert llama_config.max_model_len == 8192

    # Llama 3 with both rope_scaling and rope_theta overridden.
    llama_config = build_config(
        llama_repo,
        hf_overrides={
            "rope_scaling": rope_override,
            "rope_theta": theta_override,
        },
    )
    llama_hf = llama_config.hf_config
    assert getattr(llama_hf, "rope_scaling", None) == rope_override
    assert getattr(llama_hf, "rope_theta", None) == theta_override
    assert llama_config.max_model_len == 16384

    # LongChat without overrides: every expected entry must be present in
    # the model's own rope_scaling.
    longchat_config = build_config(longchat_repo)
    for key, value in longchat_native_rope.items():
        assert longchat_config.hf_config.rope_scaling.get(key) == value
    assert longchat_config.max_model_len == 16384

    # LongChat with rope_scaling overridden.
    longchat_config = build_config(
        longchat_repo,
        hf_overrides={
            "rope_scaling": rope_override,
        },
    )
    longchat_hf = longchat_config.hf_config
    assert getattr(longchat_hf, "rope_scaling", None) == rope_override
    assert longchat_config.max_model_len == 4096
313
314


315
316
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(
    ("model_id", "is_encoder_decoder"),
    [
        (os.path.join(models_path_prefix, repo), expected_flag)
        for repo, expected_flag in (
            ("facebook/opt-125m", False),
            ("facebook/bart-base", True),
            ("meta-llama/Llama-3.2-1B-Instruct", False),
            ("meta-llama/Llama-3.2-11B-Vision", True),
        )
    ],
)
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """ModelConfig.is_encoder_decoder must match the expected flag."""
    model_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    model_config = ModelConfig(model_id, **model_kwargs)

    assert model_config.is_encoder_decoder == is_encoder_decoder


@pytest.mark.parametrize(
    ("model_id", "uses_mrope"),
    [
        (os.path.join(models_path_prefix, repo), expected_flag)
        for repo, expected_flag in (
            ("facebook/opt-125m", False),
            ("Qwen/Qwen2-VL-2B-Instruct", True),
        )
    ],
)
def test_uses_mrope(model_id, uses_mrope):
    """ModelConfig.uses_mrope must match the expected flag."""
    model_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    model_config = ModelConfig(model_id, **model_kwargs)

    assert model_config.uses_mrope == uses_mrope
353
354
355


def test_generation_config_loading():
    """Check get_diff_sampling_param() for each generation_config mode."""
    model_id = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct")

    def load(**generation_kwargs):
        # Only the generation-related arguments vary between the cases.
        return ModelConfig(
            model_id,
            task="auto",
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
            **generation_kwargs,
        )

    # When set generation_config to "vllm", the default generation config
    # will not be loaded.
    model_config = load(generation_config="vllm")
    assert model_config.get_diff_sampling_param() == {}

    # When set generation_config to "auto", the default generation config
    # should be loaded.
    model_config = load(generation_config="auto")
    model_defaults = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }
    assert model_config.get_diff_sampling_param() == model_defaults

    # The generation config could be overridden by the user.
    user_override = {"temperature": 0.5, "top_k": 5}
    model_config = load(generation_config="auto",
                        override_generation_config=user_override)
    merged = {**model_defaults, **user_override}
    assert model_config.get_diff_sampling_param() == merged

    # When generation_config is set to "vllm" and override_generation_config
    # is set, the override_generation_config should be used directly.
    model_config = load(generation_config="vllm",
                        override_generation_config=user_override)
    assert model_config.get_diff_sampling_param() == user_override
423
424
425
426
427
428
429
430
431
432
433
434
435


@pytest.mark.parametrize(
    "pt_load_map_location",
    ["cuda", {"": "cuda"}],
)
def test_load_config_pt_load_map_location(pt_load_map_location):
    """pt_load_map_location must be unchanged on VllmConfig.load_config."""
    loader_settings = LoadConfig(pt_load_map_location=pt_load_map_location)
    vllm_config = VllmConfig(load_config=loader_settings)

    stored = vllm_config.load_config.pt_load_map_location
    assert stored == pt_load_map_location
436
437
438
439
440
441
442


@pytest.mark.parametrize(
    ("model_id", "max_model_len", "expected_max_len", "should_raise"),
    [
        # Model ids are resolved through models_path_prefix, like the other
        # model-loading tests in this module, instead of bare hub ids.
        (os.path.join(models_path_prefix, repo), max_model_len,
         expected_max_len, should_raise)
        for repo, max_model_len, expected_max_len, should_raise in (
            ("BAAI/bge-reranker-base", None, 512, False),
            ("BAAI/bge-reranker-base", 256, 256, False),
            ("BAAI/bge-reranker-base", 513, 512, True),
            ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", None, 131072, False),
            ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 131073, 131072, True),
        )
    ])
def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
                                should_raise):
    """Test get_and_verify_max_len with different configurations.

    When should_raise is set, the call is expected to raise ValueError and
    expected_max_len is not checked; otherwise the returned length must
    equal expected_max_len.
    """
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    if should_raise:
        with pytest.raises(ValueError):
            model_config.get_and_verify_max_len(max_model_len)
    else:
        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
        assert actual_max_len == expected_max_len