test_generate.py 6.66 KB
Newer Older
1
2
import weakref
from typing import List
3
import os
4

5
6
import pytest

7
8
from vllm import LLM, RequestOutput, SamplingParams

9
from ...conftest import cleanup
10
from ..openai.test_vision import TEST_IMAGE_URLS
11
from ...utils import models_path_prefix
12

13
# Model path handed to the module-scoped ``llm`` fixture below; resolved
# relative to ``models_path_prefix``.
MODEL_NAME = os.path.join(models_path_prefix, "facebook/opt-125m")
14
15
16
17
18
19
20

# Text prompts shared by the string-prompt tests: used one at a time via
# ``parametrize`` and all together in the multi-prompt tests.
PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
21

22
23
24
25
26
27
# Token-id prompts shared by the ``prompt_token_ids`` tests: one list of ids
# per prompt, with lengths 1 through 4.
TOKEN_IDS = [
    [0],
    [0, 1],
    [0, 2, 1],
    [0, 3, 1, 2],
]
28

29
30
31
32
33
34

@pytest.fixture(scope="module")
def llm():
    """Provide one module-scoped ``LLM`` built from ``MODEL_NAME``.

    Yields:
        A ``weakref.proxy`` to the engine, handed out while the
        ``LLM.deprecate_legacy_api()`` context manager is entered.
    """
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME,
              max_num_batched_tokens=4096,
              tensor_parallel_size=1,
              gpu_memory_utilization=0.10,
              enforce_eager=True)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        # Drop this function's strong reference so only the weak proxy
        # handed to the tests remains.
        del llm

    # Runs once the module's tests are done with the fixture.
    cleanup()


def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
    """Assert that both result lists carry equal ``outputs``, item by item.

    Args:
        o1: First list of results; each item exposes an ``outputs`` attribute.
        o2: Second list of results to compare against ``o1``.

    Raises:
        AssertionError: If the collected ``outputs`` lists differ.
    """
    lhs = [result.outputs for result in o1]
    rhs = [result.outputs for result in o2]
    assert lhs == rhs


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
    """Compare the ``prompts=`` keyword call with the positional forms.

    The keyword call must emit a ``DeprecationWarning`` mentioning
    ``'prompts'``; a bare string and a ``{"prompt": ...}`` dict passed
    positionally must produce equal outputs.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    with pytest.warns(DeprecationWarning, match="'prompts'"):
        legacy_output = llm.generate(prompts=prompt, sampling_params=params)

    # Both positional spellings are checked against the keyword-call result.
    for positional_input in (prompt, {"prompt": prompt}):
        current_output = llm.generate(positional_input,
                                      sampling_params=params)
        assert_outputs_equal(legacy_output, current_output)


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
                                                    prompt_token_ids):
    """Compare the ``prompt_token_ids=`` keyword call with the dict form.

    The keyword call must emit a ``DeprecationWarning`` mentioning
    ``'prompt_token_ids'``; the positional ``{"prompt_token_ids": ...}``
    dict must produce equal outputs.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
        legacy_output = llm.generate(prompt_token_ids=prompt_token_ids,
                                     sampling_params=params)

    dict_input = {"prompt_token_ids": prompt_token_ids}
    current_output = llm.generate(dict_input, sampling_params=params)
    assert_outputs_equal(legacy_output, current_output)


84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
    """Compare the ``prompts=`` keyword call with positional batches.

    All of ``PROMPTS`` is submitted at once: first through the keyword
    argument (which must warn), then positionally as a list of strings and
    as a list of ``{"prompt": ...}`` dicts.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    with pytest.warns(DeprecationWarning, match="'prompts'"):
        legacy_output = llm.generate(prompts=PROMPTS, sampling_params=params)

    as_dicts = [{"prompt": text} for text in PROMPTS]

    # Both positional spellings are checked against the keyword-call result.
    for positional_batch in (PROMPTS, as_dicts):
        current_output = llm.generate(positional_batch,
                                      sampling_params=params)
        assert_outputs_equal(legacy_output, current_output)


104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
    """Compare the ``prompt_token_ids=`` keyword call with a dict batch.

    All of ``TOKEN_IDS`` is submitted at once: first through the keyword
    argument (which must warn), then positionally as a list of
    ``{"prompt_token_ids": ...}`` dicts.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
        legacy_output = llm.generate(prompt_token_ids=TOKEN_IDS,
                                     sampling_params=params)

    as_dicts = [{"prompt_token_ids": ids} for ids in TOKEN_IDS]
    current_output = llm.generate(as_dicts, sampling_params=params)
    assert_outputs_equal(legacy_output, current_output)
119
120


121
122
@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
    """Check how ``LLM.generate`` accepts the ``sampling_params`` argument.

    Covers a list with one entry per prompt, a list whose length does not
    match ``PROMPTS`` (expected to raise ``ValueError``), a single
    ``SamplingParams`` instance, and ``None``.
    """
    sampling_params = [
        SamplingParams(temperature=0.01, top_p=0.95),
        SamplingParams(temperature=0.3, top_p=0.95),
        SamplingParams(temperature=0.7, top_p=0.95),
        SamplingParams(temperature=0.99, top_p=0.95),
    ]

    # Multiple SamplingParams should be matched with each prompt
    outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
    assert len(PROMPTS) == len(outputs)

    # Exception raised, if the size of params does not match the size of prompts
    with pytest.raises(ValueError):
        # The call is expected to raise, so its result is not bound.
        llm.generate(PROMPTS, sampling_params=sampling_params[:3])

    # Single SamplingParams should be applied to every prompt
    single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
    outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
    assert len(PROMPTS) == len(outputs)

    # sampling_params is None, default params should be applied
    outputs = llm.generate(PROMPTS, sampling_params=None)
    assert len(PROMPTS) == len(outputs)
nunjunj's avatar
nunjunj committed
146
147
148
149


def test_chat():
    """Check that ``LLM.chat`` returns one output for one conversation."""
    llm = LLM(model=os.path.join(models_path_prefix,
                                 "meta-llama/Meta-Llama-3-8B-Instruct"))

    prompt1 = "Explain the concept of entropy."
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": prompt1
        },
    ]
    outputs = llm.chat(messages)
    assert len(outputs) == 1
165
166


167
168
def test_multi_chat():
    """Check that ``LLM.chat`` returns one output per conversation in a batch."""
    llm = LLM(model=os.path.join(models_path_prefix,
                                 "meta-llama/Meta-Llama-3-8B-Instruct"))

    prompt1 = "Explain the concept of entropy."
    prompt2 = "Explain what among us is."

    conversation1 = [
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": prompt1
        },
    ]

    conversation2 = [
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": prompt2
        },
    ]

    # A list of conversations is submitted in a single chat() call.
    messages = [conversation1, conversation2]

    outputs = llm.chat(messages)
    assert len(outputs) == 2


202
203
204
205
@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
    """Check that ``LLM.chat`` handles one user message with several images.

    Args:
        image_urls: URLs placed into the message as ``image_url`` content
            parts, ahead of the text question.
    """
    llm = LLM(
        model=os.path.join(models_path_prefix,
                           "microsoft/Phi-3.5-vision-instruct"),
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        trust_remote_code=True,
        # Two images are sent in one prompt, so allow two per prompt.
        limit_mm_per_prompt={"image": 2},
    )

    messages = [{
        "role":
        "user",
        "content": [
            *({
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            } for image_url in image_urls),
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]
    outputs = llm.chat(messages)
    # ``len(outputs) >= 0`` can never fail; a single conversation is expected
    # to yield exactly one output, as asserted in ``test_chat``.
    assert len(outputs) == 1