"""Tests for ``vllm.config.ModelConfig`` (test_config.py)."""
import pytest

from vllm.config import ModelConfig

# (model id, expected max_model_len) pairs used to parametrize
# test_disable_sliding_window, which builds each ModelConfig with
# disable_sliding_window=True and checks the resulting max_model_len.
MODEL_IDS_EXPECTED = [
    ("Qwen/Qwen1.5-7B", 32768),
    ("mistralai/Mistral-7B-v0.1", 4096),
    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    """Check max_model_len when the sliding window is disabled.

    Each parameter is a ``(model_id, expected)`` pair from
    MODEL_IDS_EXPECTED. The config is built with
    ``disable_sliding_window=True`` and its ``max_model_len`` must equal
    ``expected``.
    """
    model_id, expected = model_id_expected
    # The same HF repo id is passed for both leading positional arguments
    # (presumably model and tokenizer; confirm against ModelConfig).
    model_config = ModelConfig(
        model_id,
        model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        disable_sliding_window=True,
    )
    assert model_config.max_model_len == expected


def test_get_sliding_window():
    """Check that ModelConfig.get_sliding_window() follows the hf_config.

    For Qwen1.5/Qwen2, ``get_sliding_window()`` must return None while
    ``hf_config.use_sliding_window`` is False, even when
    ``hf_config.sliding_window`` is set. It must return that value once the
    flag is True. For Mistral, the test only toggles
    ``hf_config.sliding_window`` and expects it to be returned unchanged
    (including None).
    """
    TEST_SLIDING_WINDOW = 4096

    qwen2_model_config = ModelConfig(
        "Qwen/Qwen1.5-7B",
        "Qwen/Qwen1.5-7B",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    # The hf_config is mutated in place to drive each case without
    # loading a second checkpoint.
    qwen2_model_config.hf_config.use_sliding_window = False
    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert qwen2_model_config.get_sliding_window() is None

    qwen2_model_config.hf_config.use_sliding_window = True
    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

    mistral_model_config = ModelConfig(
        "mistralai/Mistral-7B-v0.1",
        "mistralai/Mistral-7B-v0.1",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )
    mistral_model_config.hf_config.sliding_window = None
    assert mistral_model_config.get_sliding_window() is None

    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


def test_rope_customization():
    """Check that rope_scaling / rope_theta overrides reach the hf_config.

    First, a default Llama-3-8B-Instruct config is checked: no
    ``rope_scaling``, ``rope_theta`` of 500_000, and ``max_model_len`` of
    8192. Then the same model is built with explicit ``rope_scaling`` and
    ``rope_theta`` arguments. Both values must be visible on ``hf_config``,
    and ``max_model_len`` must become 16384.
    """
    TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
    TEST_ROPE_THETA = 16_000_000.0

    # Baseline: values taken straight from the checkpoint's config.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
    assert llama_model_config.max_model_len == 8192

    # Overrides: the expected 16384 is TEST_ROPE_SCALING["factor"] (2.0)
    # times the baseline 8192.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        rope_scaling=TEST_ROPE_SCALING,
        rope_theta=TEST_ROPE_THETA,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert getattr(llama_model_config.hf_config, "rope_theta",
                   None) == TEST_ROPE_THETA
    assert llama_model_config.max_model_len == 16384

    # TODO: add these back when the rope configs are fixed
    # LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}
    # longchat_model_config = ModelConfig(
    #     "lmsys/longchat-13b-16k",
    #     "lmsys/longchat-13b-16k",
    #     tokenizer_mode="auto",
    #     trust_remote_code=False,
    #     dtype="float16",
    #     seed=0,
    # )
    # assert getattr(longchat_model_config.hf_config, "rope_scaling",
    #                None) == LONGCHAT_ROPE_SCALING
    # assert longchat_model_config.max_model_len == 16384

    # longchat_model_config = ModelConfig(
    #     "lmsys/longchat-13b-16k",
    #     "lmsys/longchat-13b-16k",
    #     tokenizer_mode="auto",
    #     trust_remote_code=False,
    #     dtype="float16",
    #     seed=0,
    #     rope_scaling=TEST_ROPE_SCALING,
    # )
    # assert getattr(longchat_model_config.hf_config, "rope_scaling",
    #                None) == TEST_ROPE_SCALING
    # assert longchat_model_config.max_model_len == 4096