test_generate.py 5.79 KB
Newer Older
1
2
3
import weakref
from typing import List

4
5
import pytest

6
7
from vllm import LLM, RequestOutput, SamplingParams

8
from ...conftest import cleanup
9
from ..openai.test_vision import TEST_IMAGE_URLS
10
11
12
13
14
15
16
17
18

# Model identifier passed to ``LLM(model=...)`` by the module-scoped ``llm``
# fixture below.
MODEL_NAME = "facebook/opt-125m"

# Text prompts: used one at a time to parametrize the single-prompt string
# test, and passed together as a batch to the multi-prompt tests.
PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
19

20
21
22
23
24
25
# Pre-tokenized prompts (lists of token ids, one to four ids long): used one at
# a time to parametrize the single-prompt token test, and passed together as a
# batch to the multi-prompt token test.
TOKEN_IDS = [
    [0],
    [0, 1],
    [0, 2, 1],
    [0, 3, 1, 2],
]
26

27
28
29
30
31
32

@pytest.fixture(scope="module")
def llm():
    """Module-scoped fixture yielding a weak proxy to a shared ``LLM``.

    The engine is built once from ``MODEL_NAME`` and reused by every test in
    this module that requests ``llm``. Tests run while the
    ``deprecate_legacy_api()`` context is active.
    NOTE(review): that context presumably makes legacy keyword usage emit the
    ``DeprecationWarning`` the tests assert on -- confirm against
    ``LLM.deprecate_legacy_api``.

    Yields:
        A ``weakref.proxy`` wrapping the ``LLM`` instance.
    """
    engine = LLM(model=MODEL_NAME,
                 max_num_batched_tokens=4096,
                 tensor_parallel_size=1,
                 gpu_memory_utilization=0.10,
                 enforce_eager=True)

    with engine.deprecate_legacy_api():
        # pytest caches the fixture value, so hand out a weakref.proxy rather
        # than a strong reference to keep the engine garbage-collectable.
        yield weakref.proxy(engine)

        # Drop the only strong local reference before cleanup() runs.
        del engine

    cleanup()


def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
    """Assert that two result batches carry equal ``outputs``, item by item.

    Only the ``outputs`` attribute of each element is compared; any other
    attributes of the elements are ignored.

    Args:
        o1: First batch of request outputs.
        o2: Second batch of request outputs.

    Raises:
        AssertionError: If the two lists of ``outputs`` are not equal.
    """
    first_outputs = [request.outputs for request in o1]
    second_outputs = [request.outputs for request in o2]
    assert first_outputs == second_outputs


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
    """Check a single text prompt gives equal outputs via the v1 and v2 APIs.

    The legacy ``prompts=`` keyword call must emit a ``DeprecationWarning``
    mentioning ``'prompts'``. Its outputs are then compared against the
    positional string form and the ``{"prompt": ...}`` dict form.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    # Legacy keyword form: expected to warn about the 'prompts' argument.
    with pytest.warns(DeprecationWarning, match="'prompts'"):
        legacy_result = llm.generate(prompts=prompt, sampling_params=params)

    # New-style inputs, tried in order: bare string, then prompt dict.
    for new_style_input in (prompt, {"prompt": prompt}):
        current_result = llm.generate(new_style_input,
                                      sampling_params=params)
        assert_outputs_equal(legacy_result, current_result)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
                                                    prompt_token_ids):
    """Check a single token-id prompt gives equal outputs via v1 and v2 APIs.

    The legacy ``prompt_token_ids=`` keyword call must emit a
    ``DeprecationWarning`` mentioning ``'prompt_token_ids'``. Its outputs are
    compared against the ``{"prompt_token_ids": ...}`` dict form.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    # Legacy keyword form: expected to warn about 'prompt_token_ids'.
    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
        legacy_result = llm.generate(prompt_token_ids=prompt_token_ids,
                                     sampling_params=params)

    # New-style form: the same ids wrapped in a prompt dict.
    tokens_prompt = {"prompt_token_ids": prompt_token_ids}
    current_result = llm.generate(tokens_prompt, sampling_params=params)

    assert_outputs_equal(legacy_result, current_result)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
    """Check a batch of text prompts gives equal outputs via v1 and v2 APIs.

    All of ``PROMPTS`` is submitted through the legacy ``prompts=`` keyword,
    which must emit a ``DeprecationWarning`` mentioning ``'prompts'``. The
    outputs are compared against the positional list form and a list of
    ``{"prompt": ...}`` dicts.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    # Legacy keyword form: expected to warn about the 'prompts' argument.
    with pytest.warns(DeprecationWarning, match="'prompts'"):
        legacy_result = llm.generate(prompts=PROMPTS, sampling_params=params)

    # New-style form 1: the list of strings passed positionally.
    current_result = llm.generate(PROMPTS, sampling_params=params)
    assert_outputs_equal(legacy_result, current_result)

    # New-style form 2: one prompt dict per string.
    dict_prompts = [{"prompt": text} for text in PROMPTS]
    current_result = llm.generate(dict_prompts, sampling_params=params)
    assert_outputs_equal(legacy_result, current_result)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
    """Check a batch of token-id prompts gives equal outputs via v1 and v2.

    All of ``TOKEN_IDS`` is submitted through the legacy
    ``prompt_token_ids=`` keyword, which must emit a ``DeprecationWarning``
    mentioning ``'prompt_token_ids'``. The outputs are compared against a
    list of ``{"prompt_token_ids": ...}`` dicts.
    """
    params = SamplingParams(temperature=0.0, top_p=1.0)

    # Legacy keyword form: expected to warn about 'prompt_token_ids'.
    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
        legacy_result = llm.generate(prompt_token_ids=TOKEN_IDS,
                                     sampling_params=params)

    # New-style form: one prompt dict per token-id list.
    dict_prompts = [{"prompt_token_ids": ids} for ids in TOKEN_IDS]
    current_result = llm.generate(dict_prompts, sampling_params=params)

    assert_outputs_equal(legacy_result, current_result)
117
118


119
120
@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
    """Exercise the accepted shapes of ``sampling_params`` in ``generate``.

    Covers four cases against the ``PROMPTS`` batch: one ``SamplingParams``
    per prompt, a list whose length does not match the prompts (expects
    ``ValueError``), a single shared ``SamplingParams``, and ``None``.
    Successful calls only check that one output is returned per prompt.
    """
    sampling_params = [
        SamplingParams(temperature=0.01, top_p=0.95),
        SamplingParams(temperature=0.3, top_p=0.95),
        SamplingParams(temperature=0.7, top_p=0.95),
        SamplingParams(temperature=0.99, top_p=0.95),
    ]

    # Multiple SamplingParams should be matched with each prompt
    outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
    assert len(PROMPTS) == len(outputs)

    # Exception raised, if the size of params does not match the size of prompts
    with pytest.raises(ValueError):
        # The call is expected to raise, so its result is not bound.
        llm.generate(PROMPTS, sampling_params=sampling_params[:3])

    # Single SamplingParams should be applied to every prompt
    single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
    outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
    assert len(PROMPTS) == len(outputs)

    # sampling_params is None, default params should be applied
    outputs = llm.generate(PROMPTS, sampling_params=None)
    assert len(PROMPTS) == len(outputs)
nunjunj's avatar
nunjunj committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162


def test_chat():
    """Check ``LLM.chat`` returns exactly one output for one conversation.

    Builds its own ``LLM`` instead of using the module fixture, then sends a
    two-message conversation (system, then user).
    """
    engine = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    system_turn = {
        "role": "system",
        "content": "You are a helpful assistant",
    }
    user_turn = {
        "role": "user",
        "content": "Explain the concept of entropy.",
    }

    replies = engine.chat([system_turn, user_turn])
    assert len(replies) == 1
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195


@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
    """Check ``LLM.chat`` accepts one user message carrying several images.

    Builds its own ``LLM`` with ``limit_mm_per_prompt={"image": 2}`` and sends
    a single user message: one ``image_url`` part per URL, followed by a text
    question.

    Args:
        image_urls: Image URLs to attach; parametrized with the first two
            entries of ``TEST_IMAGE_URLS``.
    """
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 2},
    )

    # Image parts first, in the order given, then the text question.
    image_parts = [{
        "type": "image_url",
        "image_url": {
            "url": image_url
        }
    } for image_url in image_urls]
    messages = [{
        "role":
        "user",
        "content": [
            *image_parts,
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]
    outputs = llm.chat(messages)
    # One conversation in, one output out (same expectation as test_chat).
    # The previous ``>= 0`` check could never fail.
    assert len(outputs) == 1