test_chat.py 5.83 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import weakref
4

zhuwenwen's avatar
zhuwenwen committed
5
import os
6
7
8
import pytest

from vllm import LLM
9
from vllm.distributed import cleanup_dist_env_and_memory
10
from vllm.sampling_params import SamplingParams
11

12
13

from utils import models_path_prefix
14
from ..openai.test_vision import TEST_IMAGE_ASSETS
15
16


17
18
19
20
@pytest.fixture(scope="function")
def text_llm():
    """Yield a weak proxy to a text-only Llama-3.2-1B-Instruct ``LLM``.

    The checkpoint is resolved under ``models_path_prefix``, the same way
    ``llm_for_failure_test`` loads this model, so both fixtures read the
    same copy of the weights.

    After the test finishes, the strong reference to the engine is dropped
    and ``cleanup_dist_env_and_memory()`` is called.
    """
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(
        model=os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
        enforce_eager=True,
        seed=0,
    )

    yield weakref.proxy(llm)

    del llm

    cleanup_dist_env_and_memory()


30
31
32
33
34
35
36
37
38
@pytest.fixture(scope="function")
def llm_for_failure_test():
    """
    Engine used by the regression test for issue #26081.

    ``max_model_len`` is deliberately tiny (128) so that an over-long
    prompt is easy to produce.
    """
    checkpoint = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")
    engine = LLM(
        model=checkpoint,
        max_model_len=128,
        enforce_eager=True,
        disable_log_stats=True,
        seed=0,
    )

    # Only a weak proxy is handed to the test: pytest caches the fixture
    # value, and a strong reference would block garbage collection.
    yield weakref.proxy(engine)

    # Teardown: drop the only strong reference, then run the cleanup helper.
    del engine
    cleanup_dist_env_and_memory()


def test_chat(text_llm):
    """A single system+user conversation yields exactly one output."""
    system_turn = {"role": "system", "content": "You are a helpful assistant"}
    user_turn = {"role": "user", "content": "Explain the concept of entropy."}

    results = text_llm.chat([system_turn, user_turn])

    assert len(results) == 1


63
def test_multi_chat(text_llm):
    """A batch of two conversations yields exactly two outputs."""
    user_prompts = (
        "Explain the concept of entropy.",
        "Explain what among us is.",
    )

    # One independent conversation per prompt, each with its own system turn.
    conversations = [
        [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": user_prompt},
        ]
        for user_prompt in user_prompts
    ]

    results = text_llm.chat(conversations)

    assert len(results) == 2


83
84
85
86
@pytest.fixture(scope="function")
def vision_llm():
    """Yield a weak proxy to a Phi-3.5-vision-instruct ``LLM``.

    The engine is configured to accept up to two images per prompt.
    """
    checkpoint = os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct")
    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        enforce_eager=True,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": 2},
        seed=0,
    )

    # Only a weak proxy is handed to the test: pytest caches the fixture
    # value, and a strong reference would block garbage collection.
    yield weakref.proxy(engine)

    # Teardown: drop the only strong reference, then run the cleanup helper.
    del engine
    cleanup_dist_env_and_memory()


104
105
106
@pytest.mark.parametrize(
    "image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True
)
def test_chat_multi_image(vision_llm, image_urls: list[str]):
    """Chat with several images and one text question in a single user turn.

    Every URL in ``image_urls`` becomes an ``image_url`` content part,
    followed by one ``text`` part. The single conversation is expected to
    produce exactly one output.
    """
    messages = [
        {
            "role": "user",
            "content": [
                *(
                    {"type": "image_url", "image_url": {"url": image_url}}
                    for image_url in image_urls
                ),
                {"type": "text", "text": "What's in this image?"},
            ],
        }
    ]
    outputs = vision_llm.chat(messages)
    # One conversation in, one output out. The earlier ``>= 0`` check was
    # always true and so could never detect a failure.
    assert len(outputs) == 1
122
123


124
def test_llm_chat_tokenization_no_double_bos(text_llm):
    """
    The prompt built by LLM.chat() for a llama chat must start with one
    BOS token, not two: the chat template already supplies it, so no extra
    special tokens should be added.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]
    results = text_llm.chat(conversation)
    assert len(results) == 1

    token_ids = results[0].prompt_token_ids
    assert token_ids is not None

    bos_id = text_llm.get_tokenizer().bos_token_id

    # The first token is BOS and the second one is not.
    first_id, second_id = token_ids[0], token_ids[1]
    assert first_id == bos_id
    assert second_id != bos_id, "Double BOS"
144
145
146
147
148
149
150
151
152
153
154
155
156


@pytest.fixture(scope="function")
def thinking_llm():
    """Yield a weak proxy to a Qwen3-0.6B ``LLM``."""
    engine = LLM(
        model="Qwen/Qwen3-0.6B",
        enforce_eager=True,
        max_model_len=4096,
        seed=0,
    )

    # Only a weak proxy is handed to the test: pytest caches the fixture
    # value, and a strong reference would block garbage collection.
    yield weakref.proxy(engine)

    # Teardown: drop the only strong reference, then run the cleanup helper.
    del engine
    cleanup_dist_env_and_memory()


@pytest.mark.parametrize("enable_thinking", [True, False])
def test_chat_extra_kwargs(thinking_llm, enable_thinking):
    """``chat_template_kwargs`` passed to ``chat()`` must shape the prompt.

    The ``<think>`` token id is expected in the prompt token ids only when
    ``enable_thinking`` is False.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "What is 1+1?"},
    ]
    template_kwargs = {"enable_thinking": enable_thinking}

    results = thinking_llm.chat(conversation, chat_template_kwargs=template_kwargs)
    assert len(results) == 1

    token_ids = results[0].prompt_token_ids
    assert token_ids is not None

    vocab = thinking_llm.get_tokenizer().get_vocab()
    think_id = vocab["<think>"]

    if not enable_thinking:
        # The chat template includes dummy thinking process
        assert think_id in token_ids
    else:
        assert think_id not in token_ids
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215


def test_chat_batch_failure_cleanup(llm_for_failure_test):
    """
    Regression test for issue #26081.

    When a batched llm.chat() call fails part-way through because one
    prompt is invalid, the requests enqueued before the failure must be
    aborted, so that a later call sees a clean queue.
    """
    engine = llm_for_failure_test
    params = SamplingParams(temperature=0, max_tokens=10)

    ok_chat = [{"role": "user", "content": "Hello"}]
    oversized_chat = [
        {
            "role": "user",
            "content": "This is a very long text to test the error " * 50,
        }
    ]

    # Two good conversations are queued ahead of the one that is too long.
    failing_batch = [ok_chat, ok_chat, oversized_chat]
    clean_batch = [ok_chat, ok_chat]

    with pytest.raises(ValueError, match="longer than the maximum model length"):
        engine.chat(failing_batch, sampling_params=params)

    # The follow-up call returns only its own outputs and leaves nothing
    # unfinished in the engine.
    followup = engine.chat(clean_batch, sampling_params=params)
    assert len(followup) == len(clean_batch)
    assert engine.llm_engine.get_num_unfinished_requests() == 0