# test_config.py
# SPDX-License-Identifier: Apache-2.0

3
4
from dataclasses import asdict

5
import pytest
6
import os
7

8
from vllm.config import ModelConfig, PoolerConfig
9
10
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
11
from utils import models_path_prefix
12

13
14
from .conftest import MODEL_WEIGHTS_S3_BUCKET

15

16
17
18
@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [
        # NOTE(review): the first four models load from
        # MODEL_WEIGHTS_S3_BUCKET and the rest from models_path_prefix —
        # confirm this split is intentional.
        (f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2", "generate",
         "generate"),
        (f"{MODEL_WEIGHTS_S3_BUCKET}/intfloat/e5-mistral-7b-instruct",
         "pooling", "embed"),
        (f"{MODEL_WEIGHTS_S3_BUCKET}/jason9693/Qwen2.5-1.5B-apeach", "pooling",
         "classify"),
        (f"{MODEL_WEIGHTS_S3_BUCKET}/cross-encoder/ms-marco-MiniLM-L-6-v2",
         "pooling", "score"),
        (os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"),
         "pooling", "reward"),
        (os.path.join(models_path_prefix, "openai/whisper-small"),
         "transcription", "transcription"),
    ],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
    """Check that ``task="auto"`` resolves to the expected runner and task.

    Args:
        model_id: Path/URI of the model (and tokenizer) to load.
        expected_runner_type: Value ``config.runner_type`` must take.
        expected_task: Value ``config.task`` must take.
    """
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task


@pytest.mark.parametrize(("model_id", "bad_task"), [
    (os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"), "generate"),
])
def test_incorrect_task(model_id, bad_task):
    """Requesting a task the model cannot run must raise ``ValueError``.

    The error message is expected to match ``does not support the .* task``.
    """
    with pytest.raises(ValueError, match=r"does not support the .* task"):
        ModelConfig(
            model_id,
            task=bad_task,
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
        )


# (model path, expected max_model_len) pairs used by
# test_disable_sliding_window, which builds each config with
# disable_sliding_window=True.
MODEL_IDS_EXPECTED = [
    (os.path.join(models_path_prefix, "Qwen/Qwen1.5-7B"), 32768),
    (os.path.join(models_path_prefix, "mistralai/Mistral-7B-v0.1"), 4096),
    (os.path.join(models_path_prefix,
                  "mistralai/Mistral-7B-Instruct-v0.2"), 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """With ``disable_sliding_window=True``, check the resulting max_model_len.

    Args:
        model_id_expected: ``(model_id, expected_max_model_len)`` tuple.
    """
    model_id, expected = model_id_expected
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        disable_sliding_window=True,
    )
    assert model_config.max_model_len == expected

def test_get_sliding_window():
    """Check ``get_sliding_window()`` against the hf_config window fields.

    For Qwen1.5 both ``use_sliding_window`` and ``sliding_window`` are
    toggled; for Mistral only ``sliding_window`` is toggled.
    """
    TEST_SLIDING_WINDOW = 4096
    qwen2_path = os.path.join(models_path_prefix, "Qwen/Qwen1.5-7B")
    mistral_path = os.path.join(models_path_prefix,
                                "mistralai/Mistral-7B-v0.1")

    # For Qwen1.5/Qwen2, get_sliding_window() should be None
    # when use_sliding_window is False.
    qwen2_model_config = ModelConfig(
        qwen2_path,
        task="auto",
        tokenizer=qwen2_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    qwen2_model_config.hf_config.use_sliding_window = False
    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert qwen2_model_config.get_sliding_window() is None

    qwen2_model_config.hf_config.use_sliding_window = True
    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

    mistral_model_config = ModelConfig(
        mistral_path,
        task="auto",
        tokenizer=mistral_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )
    mistral_model_config.hf_config.sliding_window = None
    assert mistral_model_config.get_sliding_window() is None

    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
    """Without an override, the pooler config should be normalized MEAN."""
    # NOTE(review): this uses a hub ID rather than models_path_prefix like
    # most other tests here — confirm it resolves in this environment.
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    # Passing None means no user-supplied PoolerConfig override.
    pooling_config = model_config._init_pooler_config(None)
    assert pooling_config is not None

    assert pooling_config.normalize
    assert pooling_config.pooling_type == PoolingType.MEAN.name


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
    """An explicit PoolerConfig override is returned with identical fields."""
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(model_id,
                               task="auto",
                               tokenizer=model_id,
                               tokenizer_mode="auto",
                               trust_remote_code=False,
                               seed=0,
                               dtype="float16",
                               revision=None)

    override_config = PoolerConfig(pooling_type='CLS', normalize=True)

    pooling_config = model_config._init_pooler_config(override_config)
    assert pooling_config is not None
    # Compare field-by-field so the test does not depend on object identity.
    assert asdict(pooling_config) == asdict(override_config)


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
    """Read the encoder config for BGE via ``_get_encoder_config()``.

    The returned mapping is expected to report ``max_seq_length == 512`` and
    a truthy ``do_lower_case``.
    """
    bge_id = "BAAI/bge-base-en-v1.5"
    bge_model_config = ModelConfig(
        model=bge_id,
        tokenizer=bge_id,
        task="auto",
        tokenizer_mode="auto",
        trust_remote_code=False,
        revision=None,
        dtype="float16",
        seed=0,
    )

    encoder_cfg = bge_model_config._get_encoder_config()

    assert encoder_cfg["max_seq_length"] == 512
    assert encoder_cfg["do_lower_case"]


def test_rope_customization():
    """Check RoPE settings with and without ``hf_overrides``.

    Covers Llama-3-8B-Instruct (defaults, then rope_scaling + rope_theta
    overrides) and longchat-13b-16k (defaults, then a rope_scaling override),
    asserting the resulting hf_config fields and ``max_model_len``.
    """
    TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
    TEST_ROPE_THETA = 16_000_000.0
    LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}

    llama_path = os.path.join(models_path_prefix,
                              "meta-llama/Meta-Llama-3-8B-Instruct")
    longchat_path = os.path.join(models_path_prefix, "lmsys/longchat-13b-16k")

    # Llama-3 without overrides: no rope_scaling, rope_theta of 500k.
    llama_model_config = ModelConfig(
        llama_path,
        task="auto",
        tokenizer=llama_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
    assert llama_model_config.max_model_len == 8192

    # Llama-3 with both RoPE fields overridden through hf_overrides.
    llama_model_config = ModelConfig(
        llama_path,
        task="auto",
        tokenizer=llama_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        hf_overrides={
            "rope_scaling": TEST_ROPE_SCALING,
            "rope_theta": TEST_ROPE_THETA,
        },
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert getattr(llama_model_config.hf_config, "rope_theta",
                   None) == TEST_ROPE_THETA
    assert llama_model_config.max_model_len == 16384

    # longchat without overrides: its config already carries rope_scaling.
    longchat_model_config = ModelConfig(
        longchat_path,
        task="auto",
        tokenizer=longchat_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
    assert all(
        longchat_model_config.hf_config.rope_scaling.get(key) == value
        for key, value in LONGCHAT_ROPE_SCALING.items())
    assert longchat_model_config.max_model_len == 16384

    # longchat with its rope_scaling replaced through hf_overrides.
    longchat_model_config = ModelConfig(
        longchat_path,
        task="auto",
        tokenizer=longchat_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        hf_overrides={
            "rope_scaling": TEST_ROPE_SCALING,
        },
    )
    assert getattr(longchat_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert longchat_model_config.max_model_len == 4096


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
    ("facebook/opt-125m", False),
    ("facebook/bart-base", True),
    ("meta-llama/Llama-3.2-1B-Instruct", False),
    ("meta-llama/Llama-3.2-11B-Vision", True),
])
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """``ModelConfig.is_encoder_decoder`` matches the expected flag.

    Args:
        model_id: Hub ID of the model (and tokenizer) to load.
        is_encoder_decoder: Expected value of ``config.is_encoder_decoder``.
    """
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )

    assert config.is_encoder_decoder == is_encoder_decoder


@pytest.mark.parametrize(("model_id", "uses_mrope"), [
    ("facebook/opt-125m", False),
    ("Qwen/Qwen2-VL-2B-Instruct", True),
])
def test_uses_mrope(model_id, uses_mrope):
    """``ModelConfig.uses_mrope`` matches the expected flag.

    Args:
        model_id: Hub ID of the model (and tokenizer) to load.
        uses_mrope: Expected value of ``config.uses_mrope``.
    """
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )

    assert config.uses_mrope == uses_mrope


def test_generation_config_loading():
    """Check how ``get_diff_sampling_param()`` combines generation settings.

    Four cases:
    - ``generation_config=None`` gives no sampling params;
    - ``generation_config="auto"`` gives the model's defaults;
    - ``"auto"`` plus overrides layers the overrides on top of the defaults;
    - ``None`` plus overrides gives the overrides alone.
    """
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"

    def build(**generation_kwargs):
        # All four configs share these arguments; only the generation
        # options passed in generation_kwargs differ between cases.
        return ModelConfig(model_id,
                           task="auto",
                           tokenizer=model_id,
                           tokenizer_mode="auto",
                           trust_remote_code=False,
                           seed=0,
                           dtype="float16",
                           **generation_kwargs)

    # generation_config=None: the default generation config is not loaded.
    no_defaults = build(generation_config=None)
    assert no_defaults.get_diff_sampling_param() == {}

    # generation_config="auto": the default generation config is loaded.
    with_defaults = build(generation_config="auto")

    expected_defaults = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }

    assert with_defaults.get_diff_sampling_param() == expected_defaults

    # User overrides take precedence over the loaded defaults.
    overrides = {"temperature": 0.5, "top_k": 5}

    merged = build(generation_config="auto",
                   override_generation_config=overrides)
    assert merged.get_diff_sampling_param() == {
        **expected_defaults,
        **overrides,
    }

    # With no defaults loaded, the overrides are used directly.
    overrides_only = build(generation_config=None,
                           override_generation_config=overrides)
    assert overrides_only.get_diff_sampling_param() == overrides