test_config.py 13 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from dataclasses import MISSING, Field, asdict, dataclass, field
4

5
import pytest
6
import os
7

8
from vllm.config import ModelConfig, PoolerConfig, get_field
9
10
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
11
from utils import models_path_prefix
12

13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def test_get_field():
    """Exercise ``get_field`` on a small dataclass.

    The test expects a ``ValueError`` for a field that declares neither a
    default nor a default factory, and a ``dataclasses.Field`` carrying the
    declared default (or factory) otherwise.
    """

    @dataclass
    class TestConfig:
        a: int
        b: dict = field(default_factory=dict)
        c: str = "default"

    # "a" has no default of any kind.
    with pytest.raises(ValueError):
        get_field(TestConfig, "a")

    # "b" is declared with a default factory only.
    factory_field = get_field(TestConfig, "b")
    assert isinstance(factory_field, Field)
    assert factory_field.default is MISSING
    assert factory_field.default_factory is dict

    # "c" is declared with a plain default value only.
    value_field = get_field(TestConfig, "c")
    assert isinstance(value_field, Field)
    assert value_field.default == "default"
    assert value_field.default_factory is MISSING


36
37
38
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [
        (os.path.join(models_path_prefix, repo), runner_type, task)
        for repo, runner_type, task in (
            ("distilbert/distilgpt2", "generate", "generate"),
            ("intfloat/multilingual-e5-small", "pooling", "embed"),
            ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
            ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
            ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
            ("openai/whisper-small", "transcription", "transcription"),
        )
    ],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
    """With ``task="auto"``, check the runner type and task on the config."""
    auto_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    assert auto_config.runner_type == expected_runner_type
    assert auto_config.task == expected_task


@pytest.mark.parametrize(
    ("model_id", "bad_task"),
    [
        (os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"),
         "generate"),
    ],
)
def test_incorrect_task(model_id, bad_task):
    """Building a config with ``bad_task`` must raise ``ValueError``."""
    config_kwargs = dict(
        task=bad_task,
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    # Only the constructor call sits inside the ``raises`` context.
    with pytest.raises(ValueError, match=r"does not support the .* task"):
        ModelConfig(model_id, **config_kwargs)


78
# (model path, max_model_len expected when sliding window is disabled)
MODEL_IDS_EXPECTED = [
    (os.path.join(models_path_prefix, repo), max_len)
    for repo, max_len in (
        ("Qwen/Qwen1.5-7B", 32768),
        ("mistralai/Mistral-7B-v0.1", 4096),
        ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
    )
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """Check ``max_model_len`` when ``disable_sliding_window=True``."""
    model_path, expected_max_len = model_id_expected

    config = ModelConfig(
        model=model_path,
        task="auto",
        tokenizer=model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        disable_sliding_window=True,
    )

    assert config.max_model_len == expected_max_len

101
102
103
104
105
106
107

def test_get_sliding_window():
    """Check ``get_sliding_window()`` after editing ``hf_config`` in place."""
    TEST_SLIDING_WINDOW = 4096

    # Constructor arguments shared by both configs built below.
    shared_kwargs = dict(
        task="auto",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    # Test that the sliding window is correctly computed.
    # For Qwen1.5/Qwen2, get_sliding_window() should be None
    # when use_sliding_window is False.
    qwen_path = os.path.join(models_path_prefix, "Qwen/Qwen1.5-7B")
    qwen_config = ModelConfig(qwen_path, tokenizer=qwen_path, **shared_kwargs)

    qwen_hf = qwen_config.hf_config
    qwen_hf.use_sliding_window = False
    qwen_hf.sliding_window = TEST_SLIDING_WINDOW
    assert qwen_config.get_sliding_window() is None

    qwen_hf.use_sliding_window = True
    assert qwen_config.get_sliding_window() == TEST_SLIDING_WINDOW

    # For Mistral, only the sliding_window attribute is toggled.
    mistral_path = os.path.join(models_path_prefix,
                                "mistralai/Mistral-7B-v0.1")
    mistral_config = ModelConfig(mistral_path,
                                 tokenizer=mistral_path,
                                 **shared_kwargs)

    mistral_hf = mistral_config.hf_config
    mistral_hf.sliding_window = None
    assert mistral_config.get_sliding_window() is None

    mistral_hf.sliding_window = TEST_SLIDING_WINDOW
    assert mistral_config.get_sliding_window() == TEST_SLIDING_WINDOW


142
143
144
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
    """Check the pooler config produced when no override is passed."""
    model_path = os.path.join(models_path_prefix,
                              "sentence-transformers/all-MiniLM-L12-v2")
    config = ModelConfig(
        model=model_path,
        task="auto",
        tokenizer=model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    # ``None`` means no user-supplied PoolerConfig.
    pooler = config._init_pooler_config(None)

    assert pooler is not None
    assert pooler.normalize
    assert pooler.pooling_type == PoolingType.MEAN.name
162
163
164
165
166


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
    """Check that an explicit ``PoolerConfig`` override is returned as given."""
    model_path = os.path.join(models_path_prefix,
                              "sentence-transformers/all-MiniLM-L12-v2")
    config = ModelConfig(
        model_path,
        task="auto",
        tokenizer=model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    requested = PoolerConfig(pooling_type='CLS', normalize=True)
    resolved = config._init_pooler_config(requested)

    assert resolved is not None
    # Compare field by field through the dataclass dict form.
    assert asdict(resolved) == asdict(requested)
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
    """Check the encoder config values read for ``BAAI/bge-base-en-v1.5``.

    The model is resolved under ``models_path_prefix`` like every other test
    in this module, rather than through a bare hub identifier.
    """
    model_id = os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5")
    bge_model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    bert_bge_model_config = bge_model_config._get_encoder_config()

    assert bert_bge_model_config["max_seq_length"] == 512
    assert bert_bge_model_config["do_lower_case"]


204
def test_rope_customization():
    """Check RoPE settings and ``max_model_len`` with and without overrides."""
    TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
    TEST_ROPE_THETA = 16_000_000.0
    LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}

    # Constructor arguments shared by every config built in this test.
    shared_kwargs = dict(
        task="auto",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )

    llama_path = os.path.join(models_path_prefix,
                              "meta-llama/Meta-Llama-3-8B-Instruct")

    # Llama 3 without overrides.
    llama_config = ModelConfig(llama_path,
                               tokenizer=llama_path,
                               **shared_kwargs)
    llama_hf = llama_config.hf_config
    assert getattr(llama_hf, "rope_scaling", None) is None
    assert getattr(llama_hf, "rope_theta", None) == 500_000
    assert llama_config.max_model_len == 8192

    # Llama 3 with both rope_scaling and rope_theta overridden.
    llama_config = ModelConfig(
        llama_path,
        tokenizer=llama_path,
        hf_overrides={
            "rope_scaling": TEST_ROPE_SCALING,
            "rope_theta": TEST_ROPE_THETA,
        },
        **shared_kwargs,
    )
    llama_hf = llama_config.hf_config
    assert getattr(llama_hf, "rope_scaling", None) == TEST_ROPE_SCALING
    assert getattr(llama_hf, "rope_theta", None) == TEST_ROPE_THETA
    assert llama_config.max_model_len == 16384

    longchat_path = os.path.join(models_path_prefix, "lmsys/longchat-13b-16k")

    # LongChat without overrides.
    longchat_config = ModelConfig(longchat_path,
                                  tokenizer=longchat_path,
                                  **shared_kwargs)
    # Check if LONGCHAT_ROPE_SCALING entries are in longchat_config
    for key, value in LONGCHAT_ROPE_SCALING.items():
        assert longchat_config.hf_config.rope_scaling.get(key) == value
    assert longchat_config.max_model_len == 16384

    # LongChat with rope_scaling overridden.
    longchat_config = ModelConfig(
        longchat_path,
        tokenizer=longchat_path,
        hf_overrides={
            "rope_scaling": TEST_ROPE_SCALING,
        },
        **shared_kwargs,
    )
    assert getattr(longchat_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert longchat_config.max_model_len == 4096
271
272


273
274
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(
    ("model_id", "is_encoder_decoder"),
    [
        (os.path.join(models_path_prefix, repo), flag)
        for repo, flag in (
            ("facebook/opt-125m", False),
            ("facebook/bart-base", True),
            ("meta-llama/Llama-3.2-1B-Instruct", False),
            ("meta-llama/Llama-3.2-11B-Vision", True),
        )
    ],
)
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """Check ``ModelConfig.is_encoder_decoder`` for each listed model."""
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )

    assert model_config.is_encoder_decoder == is_encoder_decoder


@pytest.mark.parametrize(
    ("model_id", "uses_mrope"),
    [
        (os.path.join(models_path_prefix, repo), flag)
        for repo, flag in (
            ("facebook/opt-125m", False),
            ("Qwen/Qwen2-VL-2B-Instruct", True),
        )
    ],
)
def test_uses_mrope(model_id, uses_mrope):
    """Check ``ModelConfig.uses_mrope`` for each listed model."""
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )

    assert model_config.uses_mrope == uses_mrope
311
312
313


def test_generation_config_loading():
    """Check ``get_diff_sampling_param()`` for each generation_config mode."""
    model_id = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct")

    def diff_sampling_params(**generation_kwargs):
        """Build a config for ``model_id`` and return its sampling diff."""
        config = ModelConfig(
            model_id,
            task="auto",
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
            **generation_kwargs,
        )
        return config.get_diff_sampling_param()

    # When set generation_config to "vllm", the default generation config
    # will not be loaded.
    assert diff_sampling_params(generation_config="vllm") == {}

    # When set generation_config to "auto", the default generation config
    # should be loaded.
    correct_generation_config = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }
    assert (diff_sampling_params(generation_config="auto") ==
            correct_generation_config)

    # The generation config could be overridden by the user.
    override_generation_config = {"temperature": 0.5, "top_k": 5}
    override_result = {
        **correct_generation_config,
        **override_generation_config,
    }
    assert diff_sampling_params(
        generation_config="auto",
        override_generation_config=override_generation_config,
    ) == override_result

    # When generation_config is set to "vllm" and override_generation_config
    # is set, the override_generation_config should be used directly.
    assert diff_sampling_params(
        generation_config="vllm",
        override_generation_config=override_generation_config,
    ) == override_generation_config