test_generate.py 6.48 KB
Newer Older
1
2
3
import weakref
from typing import List

4
5
import pytest

6
7
from vllm import LLM, RequestOutput, SamplingParams

8
from ...conftest import cleanup
9
from ..openai.test_vision import TEST_IMAGE_URLS
10
11
12
13
14
15
16
17
18

# Model identifier passed to `LLM(model=...)` by the module-scoped `llm`
# fixture in this file.
MODEL_NAME = "facebook/opt-125m"

# Text prompts shared by the string-prompt tests in this file; used both
# one at a time (via parametrize) and as a whole batch.
PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
19

20
21
22
23
24
25
# Prompts given directly as token-id lists (lengths 1 through 4); used by the
# `prompt_token_ids` tests both one at a time and as a whole batch.
TOKEN_IDS = [
    [0],
    [0, 1],
    [0, 2, 1],
    [0, 3, 1, 2],
]
26

27
28
29
30
31
32

@pytest.fixture(scope="module")
def llm():
    """Provide one ``LLM`` built from ``MODEL_NAME`` to every test in the module.

    Tests receive a ``weakref.proxy`` of the engine, yielded while the
    ``deprecate_legacy_api()`` context is active (presumably this is what
    makes the legacy keyword calls emit ``DeprecationWarning`` -- confirm
    against ``LLM``). On teardown the only strong reference is deleted and
    ``cleanup()`` is called.
    """
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    engine = LLM(
        model=MODEL_NAME,
        max_num_batched_tokens=4096,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.10,
        enforce_eager=True,
    )

    with engine.deprecate_legacy_api():
        yield weakref.proxy(engine)

        # Drop the strong reference before cleanup() runs so the proxy
        # handed to the tests does not keep the engine alive.
        del engine

    cleanup()


def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
    """Assert that two result lists carry equal ``outputs``, item by item.

    Only the ``outputs`` attribute of each element is compared; every other
    attribute of the elements is ignored.

    Raises:
        AssertionError: if the two lists of ``outputs`` differ.
    """
    first_completions = [request.outputs for request in o1]
    second_completions = [request.outputs for request in o2]
    assert first_completions == second_completions


50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
    """One text prompt: the legacy ``prompts=`` call and the new forms agree.

    The legacy keyword call must emit a ``DeprecationWarning`` mentioning
    ``'prompts'``. Its result is then compared against the positional string
    form and the ``{"prompt": ...}`` dict form.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    # Legacy (v1) keyword form -- expected to warn.
    with pytest.warns(DeprecationWarning, match="'prompts'"):
        legacy_result = llm.generate(prompts=prompt, sampling_params=params)

    # New (v2) forms: bare string first, then the dict wrapper.
    for new_style_input in (prompt, {"prompt": prompt}):
        new_result = llm.generate(new_style_input, sampling_params=params)
        assert_outputs_equal(legacy_result, new_result)


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
                                                    prompt_token_ids):
    """One token-id prompt: the legacy keyword call and the dict form agree.

    The legacy ``prompt_token_ids=`` call must emit a ``DeprecationWarning``
    mentioning ``'prompt_token_ids'``; its result is compared against the
    ``{"prompt_token_ids": ...}`` dict form.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    # Legacy (v1) keyword form -- expected to warn.
    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
        legacy_result = llm.generate(
            prompt_token_ids=prompt_token_ids,
            sampling_params=params,
        )

    # New (v2) dict form.
    wrapped_prompt = {"prompt_token_ids": prompt_token_ids}
    new_result = llm.generate(wrapped_prompt, sampling_params=params)
    assert_outputs_equal(legacy_result, new_result)


82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
    """Batch of text prompts: the legacy ``prompts=`` call and new forms agree.

    The legacy keyword call must emit a ``DeprecationWarning`` mentioning
    ``'prompts'``. Its result is compared against passing ``PROMPTS``
    positionally and against a list of ``{"prompt": ...}`` dicts.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    # Legacy (v1) keyword form -- expected to warn.
    with pytest.warns(DeprecationWarning, match="'prompts'"):
        legacy_result = llm.generate(prompts=PROMPTS, sampling_params=params)

    # New (v2) form: plain list of strings.
    positional_result = llm.generate(PROMPTS, sampling_params=params)
    assert_outputs_equal(legacy_result, positional_result)

    # New (v2) form: one dict per prompt.
    wrapped_prompts = [{"prompt": text} for text in PROMPTS]
    wrapped_result = llm.generate(wrapped_prompts, sampling_params=params)
    assert_outputs_equal(legacy_result, wrapped_result)


102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
    """Batch of token-id prompts: the legacy keyword call and dict form agree.

    The legacy ``prompt_token_ids=`` call must emit a ``DeprecationWarning``
    mentioning ``'prompt_token_ids'``; its result is compared against a list
    of ``{"prompt_token_ids": ...}`` dicts.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    # Legacy (v1) keyword form -- expected to warn.
    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
        legacy_result = llm.generate(prompt_token_ids=TOKEN_IDS,
                                     sampling_params=params)

    # New (v2) form: one dict per token-id prompt.
    wrapped_prompts = [{"prompt_token_ids": ids} for ids in TOKEN_IDS]
    new_result = llm.generate(wrapped_prompts, sampling_params=params)
    assert_outputs_equal(legacy_result, new_result)
117
118


119
120
@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
    """Exercise the accepted shapes of ``sampling_params`` in ``LLM.generate``.

    Covers four cases against ``PROMPTS``: a list with one entry per prompt,
    a list of the wrong length (must raise ``ValueError``), a single shared
    ``SamplingParams``, and ``None``.
    """
    sampling_params = [
        SamplingParams(temperature=0.01, top_p=0.95),
        SamplingParams(temperature=0.3, top_p=0.95),
        SamplingParams(temperature=0.7, top_p=0.95),
        SamplingParams(temperature=0.99, top_p=0.95),
    ]

    # Multiple SamplingParams should be matched with each prompt
    outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
    assert len(PROMPTS) == len(outputs)

    # Exception raised, if the size of params does not match the size of prompts
    with pytest.raises(ValueError):
        # The call is expected to raise, so its result is not bound.
        llm.generate(PROMPTS, sampling_params=sampling_params[:3])

    # Single SamplingParams should be applied to every prompt
    single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
    outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
    assert len(PROMPTS) == len(outputs)

    # sampling_params is None, default params should be applied
    outputs = llm.generate(PROMPTS, sampling_params=None)
    assert len(PROMPTS) == len(outputs)
nunjunj's avatar
nunjunj committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162


def test_chat():
    """``LLM.chat`` on one system+user conversation yields exactly one output."""
    chat_llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    user_question = "Explain the concept of entropy."
    system_turn = {"role": "system", "content": "You are a helpful assistant"}
    user_turn = {"role": "user", "content": user_question}

    outputs = chat_llm.chat([system_turn, user_turn])
    assert len(outputs) == 1
163
164


165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def test_multi_chat():
    """``LLM.chat`` on a list of two conversations yields two outputs."""
    chat_llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    user_questions = (
        "Explain the concept of entropy.",
        "Explain what among us is.",
    )

    # Each conversation is the same system turn followed by one user question.
    conversations = [
        [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": question},
        ]
        for question in user_questions
    ]

    outputs = chat_llm.chat(conversations)
    assert len(outputs) == 2


200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
    """``LLM.chat`` accepts one user message carrying several image URLs.

    Sends a single conversation whose only message contains one
    ``image_url`` part per entry of ``image_urls`` followed by a text
    question, and expects exactly one output for that conversation.

    Args:
        image_urls: URLs placed into the message as ``image_url`` parts; the
            model is configured with ``limit_mm_per_prompt={"image": 2}``.
    """
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 2},
    )

    image_parts = [{
        "type": "image_url",
        "image_url": {
            "url": image_url
        }
    } for image_url in image_urls]
    text_part = {"type": "text", "text": "What's in this image?"}

    messages = [{
        "role": "user",
        "content": [*image_parts, text_part],
    }]
    outputs = llm.chat(messages)
    # `len(outputs) >= 0` could never fail; a single conversation should give
    # exactly one result, as asserted by test_chat for the same call shape.
    assert len(outputs) == 1