test_tokenization.py 5.65 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import pytest
4
import os
5
import pytest_asyncio
6
7
8
9
10
import requests

from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer
11
12
from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
from .test_completion import zephyr_lora_files  # noqa: F401
13
from ...utils import models_path_prefix
14
15

# Any model with a chat template should work here; the chat test renders
# prompts with ``apply_chat_template``.
MODEL_NAME = os.path.join(models_path_prefix, "HuggingFaceH4/zephyr-7b-beta")


@pytest.fixture(scope="module")
def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
    """Start one ``RemoteOpenAIServer`` for ``MODEL_NAME`` for the module.

    The server is module-scoped, so every test in this file shares it. The
    LoRA adapter files are registered under the name ``zephyr-lora2``, which
    lets tests target either the base model or the adapter through the
    request's ``model`` field.
    """
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        # The tests below expect ``max_model_len`` to be reported as 8192.
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
        # lora config
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
        "--max-lora-rank",
        "64",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def tokenizer_name(model_name: str,
                   zephyr_lora_added_tokens_files: str):  # noqa: F811
    """Resolve which tokenizer to load locally for ``model_name``.

    For the LoRA adapter (``"zephyr-lora2"``) this is the adapter's files
    directory. Presumably it carries the adapter's own tokenizer with
    added tokens; confirm against the ``zephyr_lora_added_tokens_files``
    fixture. For any other name it is the model name itself.
    """
    if model_name == "zephyr-lora2":
        return zephyr_lora_added_tokens_files
    return model_name


@pytest_asyncio.fixture
async def client(server):
    """Yield the async client from ``server.get_async_client()``.

    The client is closed when the ``async with`` block exits after the test.
    """
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenize_completions(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """Check that ``/tokenize`` on a plain prompt matches local encoding.

    The prompt is checked both with and without special tokens, for the base
    model and for the LoRA adapter. The server's token ids and count must
    equal what the locally loaded tokenizer produces.
    """
    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                              tokenizer_mode="fast")
    prompt = "vllm1 This is a test prompt."

    for add_special in [False, True]:
        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

        response = requests.post(server.url_for("tokenize"),
                                 json={
                                     "add_special_tokens": add_special,
                                     "model": model_name,
                                     "prompt": prompt
                                 })
        response.raise_for_status()

        assert response.json() == {
            "tokens": tokens,
            "count": len(tokens),
            "max_model_len": 8192
        }


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenize_chat(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """Check that ``/tokenize`` on chat messages matches local encoding.

    The prompt is rendered locally with ``apply_chat_template`` and then
    encoded. Every combination of ``add_generation_prompt``,
    ``add_special_tokens`` and ``continue_final_message`` is covered,
    except generation prompt plus continue-final, which is skipped.
    """
    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                              tokenizer_mode="fast")

    base_conversation = [{
        "role": "user",
        "content": "Hi there!"
    }, {
        "role": "assistant",
        "content": "Nice to meet you!"
    }, {
        "role": "user",
        "content": "Can I ask a question? vllm1"
    }]
    # Partial assistant turn appended when continuing the final message.
    partial_reply = {"role": "assistant", "content": "Sure,"}

    for add_generation in [False, True]:
        for add_special in [False, True]:
            for continue_final in [False, True]:
                # The test treats these two options as mutually exclusive.
                if add_generation and continue_final:
                    continue
                # Build a fresh list per case instead of appending to a
                # shared one, so no case depends on iteration order.
                conversation = (base_conversation + [partial_reply]
                                if continue_final else base_conversation)

                prompt = tokenizer.apply_chat_template(
                    add_generation_prompt=add_generation,
                    continue_final_message=continue_final,
                    conversation=conversation,
                    tokenize=False)
                tokens = tokenizer.encode(prompt,
                                          add_special_tokens=add_special)

                response = requests.post(server.url_for("tokenize"),
                                         json={
                                             "add_generation_prompt":
                                             add_generation,
                                             "continue_final_message":
                                             continue_final,
                                             "add_special_tokens": add_special,
                                             "messages": conversation,
                                             "model": model_name
                                         })
                response.raise_for_status()

                assert response.json() == {
                    "tokens": tokens,
                    "count": len(tokens),
                    "max_model_len": 8192
                }


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_detokenize(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """Check that ``/detokenize`` turns locally encoded ids back into text.

    This runs for both the base model and the LoRA adapter.
    """
    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                              tokenizer_mode="fast")

    prompt = "This is a test prompt. vllm1"
    # Encode without special tokens so the ids cover only ``prompt`` itself.
    tokens = tokenizer.encode(prompt, add_special_tokens=False)

    response = requests.post(server.url_for("detokenize"),
                             json={
                                 "model": model_name,
                                 "tokens": tokens
                             })
    response.raise_for_status()

    assert response.json() == {"prompt": prompt}