test_serving_chat.py 15.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import asyncio
7
from contextlib import suppress
8
from dataclasses import dataclass, field
9
from typing import TYPE_CHECKING, Any
10
from unittest.mock import AsyncMock, MagicMock
11

12
import pytest
13
import pytest_asyncio
14

15
from vllm.config.multimodal import MultiModalConfig
16
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
17
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
18
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
19
from vllm.transformers_utils.tokenizer import get_tokenizer
20
from vllm.v1.engine.async_llm import AsyncLLM
21

22
23
24
25
26
27
28
29
30
31
32
from ...utils import RemoteOpenAIServer

if TYPE_CHECKING:
    from openai import OpenAI

GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def monkeypatch_module():
    """Module-scoped counterpart of pytest's ``monkeypatch`` fixture.

    Yields a ``MonkeyPatch`` instance and undoes every patch made through it
    when the fixture is torn down.
    """
    from _pytest.monkeypatch import MonkeyPatch

    patcher = MonkeyPatch()
    yield patcher
    patcher.undo()


39
40
41
42
43
@pytest.fixture(
    scope="module",
    params=[True, False],
    ids=["with_tool_parser", "without_tool_parser"],
)
def with_tool_parser(request) -> bool:
    """Parametrized flag: ``True`` for the tool-parser variant, else ``False``."""
    return request.param


48
@pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool):
    """Return the CLI arguments used to launch the test server.

    The tool-call parser flags are appended only when ``with_tool_parser``
    is true.
    """
    # NOTE(review): an earlier comment here mentioned half precision, but no
    # dtype flag is passed; the list below is exactly what is sent.
    base_args = [
        "--enforce-eager",
        "--max-model-len",
        "4096",
        "--reasoning-parser",
        "openai_gptoss",
        "--gpu-memory-utilization",
        "0.8",
    ]
    if not with_tool_parser:
        return base_args

    tool_args = ["--tool-call-parser", "openai", "--enable-auto-tool-choice"]
    return base_args + tool_args


@pytest.fixture(scope="module")
def gptoss_server(
    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
):
    """Start a ``RemoteOpenAIServer`` for ``GPT_OSS_MODEL_NAME`` and yield it.

    ``VLLM_ATTENTION_BACKEND`` is set to ``TRITON_ATTN`` inside a monkeypatch
    context that stays open for as long as the server is in use.
    """
    with monkeypatch_module.context() as env_patch:
        env_patch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
        server_cm = RemoteOpenAIServer(GPT_OSS_MODEL_NAME, default_server_args)
        with server_cm as running_server:
            yield running_server


@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client obtained from the ``gptoss_server`` fixture."""
    async with gptoss_server.get_async_client() as client:
        yield client


@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Stream one chat completion and check how tool calls are surfaced.

    With the tool parser enabled the stream must carry a tool-call name and
    arguments; without it no tool call may appear and text content must.
    """
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "state": {"type": "string"},
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
    tools = [weather_tool]
    messages = [
        {"role": "user", "content": "What is the weather in Dallas, TX?"},
    ]

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        # Tools are only offered in the tool-parser variant.
        tools=tools if with_tool_parser else None,
        stream=True,
    )

    tool_name = None
    argument_parts: list[str] = []
    content_parts: list[str] = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # Only the first tool call of each delta is inspected.
            function = delta.tool_calls[0].function
            if function and function.name:
                tool_name = function.name
            if function and function.arguments:
                argument_parts.append(function.arguments)
        text = getattr(delta, "content", None)
        if text:
            content_parts.append(text)

    streamed_arguments = "".join(argument_parts)
    streamed_content = "".join(content_parts)
    if with_tool_parser:
        assert tool_name is not None
        assert streamed_arguments
    else:
        assert tool_name is None
        assert not streamed_arguments
        assert streamed_content
146
147
148


@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool):
    """Run a two-turn conversation that starts with a tool call.

    Turn one must produce a ``get_current_weather`` tool call with arguments
    and no text content. After those arguments are fed back as an assistant
    message, turn two must produce either text content or another tool call.
    """
    if not with_tool_parser:
        pytest.skip("skip non-tool for multi-turn tests")

    weather_parameters = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "state": {"type": "string"},
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city", "state", "unit"],
    }
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": weather_parameters,
            },
        }
    ]
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
    ]

    async def ask():
        # Sends the conversation as it stands at call time and returns the
        # first choice's message.
        response = await gptoss_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=messages,
            tools=tools,
            temperature=0.0,
        )
        return response.choices[0].message

    first_reply = await ask()
    assert first_reply.tool_calls
    call = first_reply.tool_calls[0]
    assert call.function is not None and call.function.name == "get_current_weather"
    raw_arguments = call.function.arguments
    assert raw_arguments
    assert not first_reply.content

    messages.extend(
        [
            {"role": "assistant", "content": raw_arguments},
            {"role": "user", "content": "Now convert to celsius and return JSON only"},
        ]
    )

    second_reply = await ask()
    assert second_reply.content or second_reply.tool_calls
208
209


210
# Model identifiers and chat template used by the mocked serving-chat tests.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# Each entry uses the same string for both ``name`` and ``model_path``.
BASE_MODEL_PATHS = [
    BaseModelPath(name=model, model_path=model)
    for model in (MODEL_NAME, MODEL_NAME_SHORT)
]
217
218


219
220
221
222
223
@dataclass
class MockHFConfig:
    """Stand-in for a HF config object; carries only ``model_type``."""

    model_type: str = "any"


224
225
@dataclass
class MockModelConfig:
    """Lightweight stand-in for the model config passed to the serving layer.

    Only the annotated attributes are dataclass fields (and therefore
    constructor arguments); the un-annotated ones are plain class attributes
    shared by every instance.
    """

    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    logits_processor_pattern = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False
    # Tests set ``hf_config.model_type`` in place, so each instance needs its
    # own MockHFConfig; a single class-level object would carry the change
    # over to every other instance. Declared after the pre-existing fields so
    # their positional order in ``__init__`` is unchanged.
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or an empty dict when it is unset."""
        return self.diff_sampling_param or {}
246
247


248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
def _build_serving_chat(
    engine: AsyncLLM, model_config: MockModelConfig
) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` for tests with input processing stubbed.

    ``_process_inputs`` is replaced by an ``AsyncMock`` whose side effect
    returns a dict copy of the engine prompt together with an empty dict, so
    tests can inspect the arguments it was awaited with.
    """

    async def _passthrough_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        # Hand the prompt straight back; every other argument is ignored.
        return dict(engine_prompt), {}

    model_registry = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
        model_config=model_config,
    )
    chat_handler = OpenAIServingChat(
        engine,
        model_config,
        model_registry,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )
    chat_handler._process_inputs = AsyncMock(side_effect=_passthrough_inputs)
    return chat_handler


281
282
283
@dataclass
class MockEngine:
    """Minimal engine double exposing only ``get_model_config``."""

    async def get_model_config(self):
        """Return a newly constructed ``MockModelConfig``."""
        config = MockModelConfig()
        return config
285
286
287


async def _async_serving_chat_init():
    """Build an ``OpenAIServingChat`` on top of ``MockEngine`` and return it."""
    mock_engine = MockEngine()
    config = await mock_engine.get_model_config()
    model_registry = OpenAIServingModels(mock_engine, config, BASE_MODEL_PATHS)

    return OpenAIServingChat(
        mock_engine,
        config,
        model_registry,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )


def test_async_serving_chat_init():
    """The chat template passed at construction is exposed as ``chat_template``."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE
307
308


309
310
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
    """The full model name is reported for a short, empty, or omitted ``model``."""
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    serving_chat = _build_serving_chat(engine, MockModelConfig())

    async def echo_model_name(*args):
        # The fourth positional argument is expected to be the resolved model
        # name; returning it lets the test compare it directly.
        return args[3]

    serving_chat.chat_completion_full_generator = echo_model_name

    messages = [{"role": "user", "content": "what is 1+1?"}]
    request_variants = (
        {"model": MODEL_NAME_SHORT},  # short name requested
        {"model": ""},  # empty string specified
        {},  # no model specified
    )
    for model_kwargs in request_variants:
        req = ChatCompletionRequest(messages=messages, **model_kwargs)
        assert await serving_chat.create_chat_completion(req) == MODEL_NAME


336
337
async def _generate_max_tokens(serving_chat, mock_engine, req):
    """Send ``req`` and return the ``max_tokens`` passed to ``generate``.

    The sampling params are read from the second positional argument of the
    most recent ``mock_engine.generate`` call.
    """
    # Forget earlier calls so that a request which fails before reaching the
    # engine cannot be satisfied by a stale ``call_args`` from a previous one.
    mock_engine.generate.reset_mock()
    # NOTE(review): exceptions are suppressed as in the original test,
    # presumably because the mocked engine produces no real outputs to consume.
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    mock_engine.generate.assert_called()
    return mock_engine.generate.call_args.args[1].max_tokens


@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    """Check how request and server ``max_tokens`` limits are combined.

    Covers three server configurations: no server-side limit, a server limit
    below the remaining context, and a server limit above it.
    """

    def fresh_serving_chat(model_config):
        # Build a new mocked engine plus serving chat for one configuration.
        engine = MagicMock(spec=AsyncLLM)
        engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
        engine.errored = False
        return engine, _build_serving_chat(engine, model_config)

    def fresh_request():
        return ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "what is 1+1?"}],
        )

    # --- No server-side limit -------------------------------------------
    mock_engine, serving_chat = fresh_serving_chat(MockModelConfig())
    req = fresh_request()

    assert await _generate_max_tokens(serving_chat, mock_engine, req) == 93

    req.max_tokens = 10
    assert await _generate_max_tokens(serving_chat, mock_engine, req) == 10

    # --- Server max_tokens (generation_config.json) lower than
    # --- context_window - prompt_tokens ---------------------------------
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 10  # Setting server-side max_tokens limit
    }
    mock_engine, serving_chat = fresh_serving_chat(mock_model_config)
    req = fresh_request()

    # Test Case 1: No max_tokens specified in request
    assert await _generate_max_tokens(serving_chat, mock_engine, req) == 10

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 15
    assert await _generate_max_tokens(serving_chat, mock_engine, req) == 10

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5
    assert await _generate_max_tokens(serving_chat, mock_engine, req) == 5

    # --- Server max_tokens (generation_config.json) higher than
    # --- context_window - prompt_tokens ---------------------------------
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 200  # Setting server-side max_tokens limit
    }
    mock_engine, serving_chat = fresh_serving_chat(mock_model_config)
    req = fresh_request()

    # Test case 1: No max_tokens specified, defaults to context_window
    assert await _generate_max_tokens(serving_chat, mock_engine, req) == 93

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 100
    assert await _generate_max_tokens(serving_chat, mock_engine, req) == 93

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5
    assert await _generate_max_tokens(serving_chat, mock_engine, req) == 5

444

445
446
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
    """Server-side sampling defaults apply unless the request overrides them."""
    server_defaults = {
        "temperature": 0.5,
        "repetition_penalty": 1.05,
    }
    mock_model_config = MockModelConfig(diff_sampling_param=server_defaults)

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    # Initialize the serving chat
    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # Nothing set on the request: both values come from the server defaults.
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)

    sampling_params = mock_engine.generate.call_args.args[1]
    assert sampling_params.temperature == 0.5
    assert sampling_params.repetition_penalty == 1.05

    # Test the param when user set it
    req.temperature = 0.1

    with suppress(Exception):
        await serving_chat.create_chat_completion(req)

    sampling_params = mock_engine.generate.call_args.args[1]
    assert sampling_params.temperature == 0.1
    assert sampling_params.repetition_penalty == 1.05

    # Test When temperature==0.0
    req.temperature = 0.0

    with suppress(Exception):
        await serving_chat.create_chat_completion(req)

    sampling_params = mock_engine.generate.call_args.args[1]
    assert sampling_params.temperature == 0.0
    assert sampling_params.repetition_penalty == 1.05
488
489


490
@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt(model_type):
    """``cache_salt`` reaches the engine prompt only when the request sets it."""
    mock_model_config = MockModelConfig()
    # Give this config its own MockHFConfig rather than editing the existing
    # one in place, so the chosen ``model_type`` cannot leak to other configs.
    mock_model_config.hf_config = MockHFConfig(model_type=model_type)

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

    # Test cache_salt
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # By default, cache_salt in the engine prompt is not set
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    # The engine prompt is the second positional argument of the first await.
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "cache_salt" not in engine_prompt

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"