# test_config.py
import pytest

3
4
from vllm.config import ModelConfig

5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

@pytest.mark.parametrize(("model_id", "expected_task"), [
    ("facebook/opt-125m", "generate"),
    ("intfloat/e5-mistral-7b-instruct", "embedding"),
])
def test_auto_task(model_id, expected_task):
    """Check that ``task="auto"`` resolves to the expected task per model."""
    # Only the model id (also used as the tokenizer) varies between cases.
    init_kwargs = {
        "task": "auto",
        "tokenizer": model_id,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "float16",
    }
    resolved_task = ModelConfig(model_id, **init_kwargs).task

    assert resolved_task == expected_task


@pytest.mark.parametrize(("model_id", "bad_task"), [
    ("facebook/opt-125m", "embedding"),
    ("intfloat/e5-mistral-7b-instruct", "generate"),
])
def test_incorrect_task(model_id, bad_task):
    """Expect ``ValueError`` when ModelConfig is built with ``bad_task``."""
    init_kwargs = {
        "task": bad_task,
        "tokenizer": model_id,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "float16",
    }
    expected_error = r"does not support the .* task"

    # Only the constructor call sits inside the context manager, so the
    # ValueError can only come from ModelConfig itself.
    with pytest.raises(ValueError, match=expected_error):
        ModelConfig(model_id, **init_kwargs)


41
42
43
44
45
46
47
48
49
50
51
52
# (model id, expected max_model_len) pairs used to parametrize
# test_disable_sliding_window, which builds each ModelConfig with
# disable_sliding_window=True and compares max_model_len to the second item.
MODEL_IDS_EXPECTED = [
    ("Qwen/Qwen1.5-7B", 32768),
    ("mistralai/Mistral-7B-v0.1", 4096),
    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """Check ``max_model_len`` when the sliding window is disabled.

    Args:
        model_id_expected: A ``(model_id, expected)`` tuple taken from
            ``MODEL_IDS_EXPECTED``; ``expected`` is the ``max_model_len``
            the resulting config must report.
    """
    model_id, expected = model_id_expected
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        disable_sliding_window=True,
    )
    assert model_config.max_model_len == expected

64
65
66
67
68
69
70
71

def test_get_sliding_window():
    """Check ``ModelConfig.get_sliding_window()`` for Qwen1.5 and Mistral.

    The HF config attributes are overwritten after construction so each
    assertion exercises one specific combination of ``use_sliding_window``
    and ``sliding_window``.
    """
    TEST_SLIDING_WINDOW = 4096
    # Test that the sliding window is correctly computed.
    # For Qwen1.5/Qwen2, get_sliding_window() should be None
    # when use_sliding_window is False.
    qwen2_model_config = ModelConfig(
        "Qwen/Qwen1.5-7B",
        task="auto",
        tokenizer="Qwen/Qwen1.5-7B",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    qwen2_model_config.hf_config.use_sliding_window = False
    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert qwen2_model_config.get_sliding_window() is None

    # Flipping the flag on must expose the configured window size.
    qwen2_model_config.hf_config.use_sliding_window = True
    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

    mistral_model_config = ModelConfig(
        "mistralai/Mistral-7B-v0.1",
        task="auto",
        tokenizer="mistralai/Mistral-7B-v0.1",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )
    # For Mistral only ``sliding_window`` is set here: None is expected to
    # come back as None, and a concrete value to be returned unchanged.
    mistral_model_config.hf_config.sliding_window = None
    assert mistral_model_config.get_sliding_window() is None

    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


105
def test_rope_customization():
    """Check that ``rope_scaling`` / ``rope_theta`` overrides reach the config.

    For each model the config is first built without overrides to pin the
    baseline values, then rebuilt with overrides; the test asserts both the
    overridden ``hf_config`` fields and the resulting ``max_model_len``.
    """
    TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
    TEST_ROPE_THETA = 16_000_000.0
    LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}

    # Baseline Llama-3 config: no rope scaling, stock rope_theta.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        task="auto",
        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
    assert llama_model_config.max_model_len == 8192

    # Same model with both overrides supplied through the constructor.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        task="auto",
        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        rope_scaling=TEST_ROPE_SCALING,
        rope_theta=TEST_ROPE_THETA,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert getattr(llama_model_config.hf_config, "rope_theta",
                   None) == TEST_ROPE_THETA
    assert llama_model_config.max_model_len == 16384

    # Baseline LongChat config, which already carries its own rope_scaling.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        task="auto",
        tokenizer="lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
    assert all(
        longchat_model_config.hf_config.rope_scaling.get(key) == value
        for key, value in LONGCHAT_ROPE_SCALING.items())
    assert longchat_model_config.max_model_len == 16384

    # Overriding the built-in scaling must replace it and change the
    # reported max_model_len accordingly.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        task="auto",
        tokenizer="lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        rope_scaling=TEST_ROPE_SCALING,
    )
    assert getattr(longchat_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert longchat_model_config.max_model_len == 4096