# "tests/kernels/quantization/test_marlin_gemm.py" did not exist on "0f41fbe5a370c0b87bb9a038be592c9272d46364"
# test_config.py 13.9 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import MISSING, Field, asdict, dataclass, field
5

6
7
import pytest

8
from vllm.compilation.backends import VllmBackend
9
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
10
                         get_field)
11
12
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
13

14

15
16
17
18
19
20
21
22
23
24
25
26
def test_compile_config_repr_succeeds():
    """``repr()`` of a VllmConfig still works after VllmBackend mutates it."""
    vllm_config = VllmConfig()
    # VllmBackend mutates the config object it is given.
    compile_backend = VllmBackend(vllm_config)
    compile_backend.configure_post_pass()

    rendered = repr(vllm_config)
    for fragment in ('VllmConfig', 'inductor_passes'):
        assert fragment in rendered


27
28
29
30
31
@dataclass
class _TestConfigFields:
    """Fixture dataclass exercising the three kinds of field defaults.

    ``a`` has no default, ``b`` has a ``default_factory`` and ``c`` has a
    plain default value; ``test_get_field`` probes each of them.
    """

    a: int  # required: no default and no default factory
    b: dict = field(default_factory=dict)  # default built by a factory
    c: str = "default"  # plain default value
32
33


34
def test_get_field():
    """``get_field`` returns the dataclass ``Field`` for defaulted fields.

    Asking for a field that has neither a default nor a default factory
    (``_TestConfigFields.a``) is expected to raise ``ValueError``.
    """
    with pytest.raises(ValueError):
        get_field(_TestConfigFields, "a")

    # Field with a default_factory: no plain default is set.
    b = get_field(_TestConfigFields, "b")
    assert isinstance(b, Field)
    assert b.default is MISSING
    assert b.default_factory is dict

    # Field with a plain default: no default_factory is set.
    c = get_field(_TestConfigFields, "c")
    assert isinstance(c, Field)
    assert c.default == "default"
    assert c.default_factory is MISSING


49
50
51
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [
        ("distilbert/distilgpt2", "generate", "generate"),
        ("intfloat/multilingual-e5-small", "pooling", "embed"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
        ("openai/whisper-small", "transcription", "transcription"),
    ],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
    """``task="auto"`` resolves to the expected runner type and task."""
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task


@pytest.mark.parametrize(("model_id", "bad_task"), [
    ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
])
def test_incorrect_task(model_id, bad_task):
    """Requesting a task the model does not support raises ``ValueError``."""
    with pytest.raises(ValueError, match=r"does not support the .* task"):
        ModelConfig(
            model_id,
            task=bad_task,
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
        )


91
92
93
94
95
96
97
98
99
100
101
102
# (model_id, expected max_model_len with disable_sliding_window=True)
MODEL_IDS_EXPECTED = [
    ("Qwen/Qwen1.5-7B", 32768),
    ("mistralai/Mistral-7B-v0.1", 4096),
    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """With ``disable_sliding_window=True``, ``max_model_len`` equals the
    per-model value listed in ``MODEL_IDS_EXPECTED``."""
    model_id, expected = model_id_expected
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        disable_sliding_window=True,
    )
    assert model_config.max_model_len == expected

114
115
116
117
118
119
120
121

def test_get_sliding_window():
    """``get_sliding_window()`` honours the HF config's sliding-window fields.

    Qwen configs gate the window on ``use_sliding_window``; Mistral configs
    use ``sliding_window`` directly (``None`` meaning no window).
    """
    TEST_SLIDING_WINDOW = 4096
    # Test that the sliding window is correctly computed.
    # For Qwen1.5/Qwen2, get_sliding_window() should be None
    # when use_sliding_window is False.
    qwen2_model_config = ModelConfig(
        "Qwen/Qwen1.5-7B",
        task="auto",
        tokenizer="Qwen/Qwen1.5-7B",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    qwen2_model_config.hf_config.use_sliding_window = False
    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert qwen2_model_config.get_sliding_window() is None

    qwen2_model_config.hf_config.use_sliding_window = True
    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

    mistral_model_config = ModelConfig(
        "mistralai/Mistral-7B-v0.1",
        task="auto",
        tokenizer="mistralai/Mistral-7B-v0.1",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )
    mistral_model_config.hf_config.sliding_window = None
    assert mistral_model_config.get_sliding_window() is None

    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


155
156
157
158
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
    """Without overrides, ``_init_pooler_config()`` yields normalized MEAN
    pooling for this sentence-transformers model."""
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    pooling_config = model_config._init_pooler_config()
    assert pooling_config is not None

    assert pooling_config.normalize
    assert pooling_config.pooling_type == PoolingType.MEAN.name
175
176
177
178
179
180


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
    """An explicit ``override_pooler_config`` makes ``_init_pooler_config()``
    return a config equal (field by field) to that override."""
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(model_id,
                               task="auto",
                               tokenizer=model_id,
                               tokenizer_mode="auto",
                               trust_remote_code=False,
                               seed=0,
                               dtype="float16",
                               revision=None)

    override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
    model_config.override_pooler_config = override_pooler_config

    pooling_config = model_config._init_pooler_config()
    assert pooling_config is not None
    # Compare as dicts so the check is on field values, not object identity.
    assert asdict(pooling_config) == asdict(override_pooler_config)
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
    """The BGE model's encoder config reports a 512-token max sequence
    length and lower-casing enabled."""
    bge_id = "BAAI/bge-base-en-v1.5"
    model_config = ModelConfig(
        model=bge_id,
        task="auto",
        tokenizer=bge_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    encoder_config = model_config._get_encoder_config()

    assert encoder_config["max_seq_length"] == 512
    assert encoder_config["do_lower_case"]


218
def test_rope_customization():
    """``hf_overrides`` can replace ``rope_scaling``/``rope_theta``, and
    ``max_model_len`` reflects the resulting RoPE settings."""
    TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
    TEST_ROPE_THETA = 16_000_000.0
    LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}

    # Baseline Llama-3: no rope_scaling, theta 500k, 8192-token context.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        task="auto",
        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
    assert llama_model_config.max_model_len == 8192

    # Overriding both values: the 2.0 scaling factor doubles max_model_len.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        task="auto",
        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        hf_overrides={
            "rope_scaling": TEST_ROPE_SCALING,
            "rope_theta": TEST_ROPE_THETA,
        },
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert getattr(llama_model_config.hf_config, "rope_theta",
                   None) == TEST_ROPE_THETA
    assert llama_model_config.max_model_len == 16384

    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        task="auto",
        tokenizer="lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
    assert all(
        longchat_model_config.hf_config.rope_scaling.get(key) == value
        for key, value in LONGCHAT_ROPE_SCALING.items())
    assert longchat_model_config.max_model_len == 16384

    # Replacing longchat's linear x8 scaling with dynamic x2 shrinks the
    # derived max_model_len accordingly.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        task="auto",
        tokenizer="lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        hf_overrides={
            "rope_scaling": TEST_ROPE_SCALING,
        },
    )
    assert getattr(longchat_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert longchat_model_config.max_model_len == 4096
285
286


287
288
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
    ("facebook/opt-125m", False),
    ("facebook/bart-base", True),
    ("meta-llama/Llama-3.2-1B-Instruct", False),
    ("meta-llama/Llama-3.2-11B-Vision", True),
])
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """``ModelConfig.is_encoder_decoder`` matches the expected flag."""
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )

    assert config.is_encoder_decoder == is_encoder_decoder


@pytest.mark.parametrize(
    ("model_id", "uses_mrope"),
    [
        ("facebook/opt-125m", False),
        ("Qwen/Qwen2-VL-2B-Instruct", True),
    ],
)
def test_uses_mrope(model_id, uses_mrope):
    """``ModelConfig.uses_mrope`` reports the expected flag per model."""
    model_config = ModelConfig(model_id,
                               task="auto",
                               tokenizer=model_id,
                               tokenizer_mode="auto",
                               trust_remote_code=False,
                               dtype="float16",
                               seed=0)
    assert model_config.uses_mrope == uses_mrope
325
326
327
328
329


def test_generation_config_loading():
    """Check how ``generation_config`` and ``override_generation_config``
    combine into ``get_diff_sampling_param()``."""
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"

    # When generation_config is set to "vllm", the default generation config
    # will not be loaded.
    model_config = ModelConfig(model_id,
                               task="auto",
                               tokenizer=model_id,
                               tokenizer_mode="auto",
                               trust_remote_code=False,
                               seed=0,
                               dtype="float16",
                               generation_config="vllm")
    assert model_config.get_diff_sampling_param() == {}

    # When generation_config is set to "auto", the default generation config
    # should be loaded.
    model_config = ModelConfig(model_id,
                               task="auto",
                               tokenizer=model_id,
                               tokenizer_mode="auto",
                               trust_remote_code=False,
                               seed=0,
                               dtype="float16",
                               generation_config="auto")

    correct_generation_config = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }

    assert model_config.get_diff_sampling_param() == correct_generation_config

    # The generation config can be overridden by the user.
    override_generation_config = {"temperature": 0.5, "top_k": 5}

    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        generation_config="auto",
        override_generation_config=override_generation_config)

    # Expected result: loaded defaults, with the user's keys taking precedence.
    override_result = correct_generation_config.copy()
    override_result.update(override_generation_config)

    assert model_config.get_diff_sampling_param() == override_result

    # When generation_config is set to "vllm" and override_generation_config
    # is set, the override_generation_config should be used directly.
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        generation_config="vllm",
        override_generation_config=override_generation_config)

    assert model_config.get_diff_sampling_param() == override_generation_config
395
396
397
398
399
400
401
402
403
404
405
406
407


@pytest.mark.parametrize("pt_load_map_location", ["cuda", {"": "cuda"}])
def test_load_config_pt_load_map_location(pt_load_map_location):
    """VllmConfig keeps the ``pt_load_map_location`` given to its LoadConfig,
    whether it is a plain device string or a mapping."""
    vllm_config = VllmConfig(load_config=LoadConfig(
        pt_load_map_location=pt_load_map_location))

    stored = vllm_config.load_config.pt_load_map_location
    assert stored == pt_load_map_location
408
409
410
411
412
413
414


@pytest.mark.parametrize(
    ("model_id", "max_model_len", "expected_max_len", "should_raise"), [
        ("BAAI/bge-reranker-base", None, 512, False),
        ("BAAI/bge-reranker-base", 256, 256, False),
        ("BAAI/bge-reranker-base", 513, 512, True),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", None, 131072, False),
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 131073, 131072, True),
    ])
def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
                                should_raise):
    """Test get_and_verify_max_len with different configurations.

    ``None`` is expected to yield the model's own maximum, a smaller value is
    returned as-is, and a value above the maximum raises ``ValueError``.
    ``expected_max_len`` is only checked when ``should_raise`` is False.
    """
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    if should_raise:
        with pytest.raises(ValueError):
            model_config.get_and_verify_max_len(max_model_len)
    else:
        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
        assert actual_max_len == expected_max_len