# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import asyncio
from contextlib import suppress
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

from ...utils import RemoteOpenAIServer

if TYPE_CHECKING:
    from openai import AsyncOpenAI, OpenAI

# Model launched by ``gptoss_server`` and requested by the gpt-oss tests.
GPT_OSS_MODEL_NAME: str = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def monkeypatch_module():
    """Module-scoped counterpart of pytest's function-scoped ``monkeypatch``.

    The built-in ``monkeypatch`` fixture cannot be requested from a
    module-scoped fixture, so this yields a standalone ``pytest.MonkeyPatch``.
    Every patch made through it is undone when the module's tests finish.
    """
    # ``MonkeyPatch.context()`` is the public way to get a standalone
    # MonkeyPatch; leaving the ``with`` block undoes all of its patches.
    with pytest.MonkeyPatch.context() as mpatch:
        yield mpatch


@pytest.fixture(
    scope="module",
    params=[True, False],
    ids=["with_tool_parser", "without_tool_parser"],
)
def with_tool_parser(request) -> bool:
    """Parametrize the gpt-oss tests over the server's tool-parser setting.

    ``True`` makes ``default_server_args`` enable the ``openai`` tool-call
    parser; ``False`` starts the server without it.
    """
    return request.param


@pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool) -> list[str]:
    """Return the CLI arguments for the module-scoped gpt-oss server.

    Args:
        with_tool_parser: When true, also enable the ``openai`` tool-call
            parser together with ``--enable-auto-tool-choice``.

    Returns:
        The argument list handed to ``RemoteOpenAIServer`` by
        ``gptoss_server``.
    """
    args = [
        # NOTE(review): no --dtype flag is passed here; --enforce-eager is
        # presumably used to shorten server startup in CI — confirm.
        "--enforce-eager",
        "--max-model-len",
        "4096",
        "--reasoning-parser",
        "openai_gptoss",
        "--gpu-memory-utilization",
        "0.8",
    ]
    if with_tool_parser:
        args.extend(
            [
                "--tool-call-parser",
                "openai",
                "--enable-auto-tool-choice",
            ]
        )
    return args


@pytest.fixture(scope="module")
def gptoss_server(
    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
):
    """Launch a gpt-oss ``RemoteOpenAIServer`` shared by the module's tests.

    Because ``default_server_args`` depends on the parametrized
    ``with_tool_parser`` fixture, one server is started per parameter value.
    ``VLLM_ATTENTION_BACKEND`` is forced to ``TRITON_ATTN`` while the server
    runs and restored afterwards.
    """
    with monkeypatch_module.context() as m:
        # Must be set before the server process is started.
        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
        with RemoteOpenAIServer(
            GPT_OSS_MODEL_NAME, default_server_args
        ) as remote_server:
            yield remote_server


@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client connected to the shared gpt-oss server.

    A new client is opened for each test, and its context is exited when the
    test finishes.
    """
    client_ctx = gptoss_server.get_async_client()
    async with client_ctx as client:
        yield client


@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(
    gptoss_client: AsyncOpenAI, with_tool_parser: bool
):
    """Stream a weather question and check how the reply is delivered.

    With the tool parser, the reply must arrive as a streamed tool call (a
    function name plus non-empty argument fragments). Without it, no tool
    call may appear and the reply must arrive as plain content.
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["city", "state", "unit"],
                },
            },
        }
    ]

    messages = [
        {"role": "user", "content": "What is the weather in Dallas, TX?"},
    ]

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools if with_tool_parser else None,
        stream=True,
    )

    name = None
    args_buf = ""
    content_buf = ""
    async for chunk in stream:
        # A chunk may carry no choices (e.g. a usage-only chunk); skip it
        # rather than failing with an IndexError.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # Only one tool is offered, so only the first call is tracked.
            tc = delta.tool_calls[0]
            if tc.function and tc.function.name:
                name = tc.function.name
            if tc.function and tc.function.arguments:
                args_buf += tc.function.arguments
        if getattr(delta, "content", None):
            content_buf += delta.content

    if with_tool_parser:
        assert name is not None
        assert len(args_buf) > 0
    else:
        assert name is None
        assert len(args_buf) == 0
        assert len(content_buf) > 0


@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(
    gptoss_client: AsyncOpenAI, with_tool_parser: bool
):
    """Run two greedy (temperature 0) chat turns against the gpt-oss server.

    The first turn must return a ``get_current_weather`` tool call with
    non-empty arguments and no content. The second turn must return either
    non-empty content or at least one tool call.
    """
    if not with_tool_parser:
        pytest.skip("skip non-tool for multi-turn tests")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["city", "state", "unit"],
                },
            },
        }
    ]

    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
    ]

    first = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        temperature=0.0,
    )
    first_msg = first.choices[0].message
    assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
    tc = first_msg.tool_calls[0]
    assert tc.function is not None and tc.function.name == "get_current_weather"
    args1 = tc.function.arguments
    assert args1 is not None and len(args1) > 0
    assert not first_msg.content

    # NOTE(review): the tool-call arguments are replayed as plain assistant
    # content, not as an assistant ``tool_calls`` message followed by a
    # ``tool`` result message. Kept as-is so the test exercises the same
    # conversation history as before.
    messages.append({"role": "assistant", "content": args1})
    messages.append(
        {"role": "user", "content": "Now convert to celsius and return JSON only"}
    )

    second = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        temperature=0.0,
    )
    second_msg = second.choices[0].message
    assert (second_msg.content is not None and len(second_msg.content) > 0) or (
        second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0
    )


# Model/tokenizer id used by the mocked OpenAIServingChat tests below.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"

# Both the full and the short name are registered, so requests naming either
# one can be resolved to a served model.
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]


@dataclass
class MockHFConfig:
    """Minimal stand-in for a Hugging Face model config.

    Only ``model_type`` is modelled; tests vary it (e.g. ``"gpt_oss"``).
    """

    # Architecture tag reported to the serving layer.
    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Duck-typed stand-in for vLLM's model config, used by the mocked engines.

    Unannotated class attributes are shared by every instance. Annotated ones
    are per-instance dataclass fields that tests may override, e.g.
    ``diff_sampling_param`` to emulate a server-side ``generation_config.json``.
    """

    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    # Tests derive their expected max_tokens values from this context window.
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    logits_processor_pattern = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False
    # Per-instance field rather than a class attribute. A shared MockHFConfig
    # would let one test's ``hf_config.model_type = ...`` leak into every other
    # test. Declared last so the positional order of existing fields is kept.
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)

    def get_diff_sampling_param(self):
        """Return the configured default sampling params, or ``{}`` if unset."""
        return self.diff_sampling_param or {}


def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` around a (mocked) engine for unit tests.

    ``_process_inputs`` is replaced by an ``AsyncMock`` that returns a shallow
    copy of the engine prompt paired with an empty dict. Tests can then
    inspect the submitted prompts through
    ``serving_chat._process_inputs.await_args_list``.

    Args:
        engine: The engine client to serve from, typically a ``MagicMock``.

    Returns:
        The configured ``OpenAIServingChat``.
    """
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_chat = OpenAIServingChat(
        engine,
        models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        # Bypass real input processing and echo the prompt back.
        return dict(engine_prompt), {}

    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return serving_chat


@dataclass
class MockEngine:
    """Minimal engine stand-in: a ``MockModelConfig`` plus mocked processors."""

    model_config: MockModelConfig = field(default_factory=MockModelConfig)
    processor: MagicMock = field(default_factory=MagicMock)
    io_processor: MagicMock = field(default_factory=MagicMock)


async def _async_serving_chat_init():
    """Construct an ``OpenAIServingChat`` over a ``MockEngine``.

    Defined as a coroutine so construction happens while an event loop is
    running (the test drives it with ``asyncio.run``).

    Returns:
        The constructed ``OpenAIServingChat``.
    """
    engine = MockEngine()
    models = OpenAIServingModels(engine, BASE_MODEL_PATHS)
    serving_chat = OpenAIServingChat(
        engine,
        models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )
    return serving_chat


def test_async_serving_chat_init():
    """The chat template passed at construction is stored on the instance."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE


@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
    """Requests naming the short model, an empty model, or no model at all
    are all served under the full model name ``MODEL_NAME``."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)
    messages = [{"role": "user", "content": "what is 1+1?"}]

    async def return_model_name(*args):
        # Echo the 4th positional argument, which this test relies on being
        # the resolved model name passed to the full generator.
        return args[3]

    serving_chat.chat_completion_full_generator = return_model_name

    # Test that full name is returned when short name is requested
    req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
    assert await serving_chat.create_chat_completion(req) == MODEL_NAME

    # Test that full name is returned when empty string is specified
    req = ChatCompletionRequest(model="", messages=messages)
    assert await serving_chat.create_chat_completion(req) == MODEL_NAME

    # Test that full name is returned when no model is specified
    req = ChatCompletionRequest(messages=messages)
    assert await serving_chat.create_chat_completion(req) == MODEL_NAME


@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` value sent to ``engine.generate``.

    The expected value is the smallest of whichever limits are set: the
    request's ``max_tokens``, the server-side default in
    ``diff_sampling_param`` (emulating ``generation_config.json``), and the
    context window left after the prompt. That last limit is 93 here, with
    ``MockModelConfig.max_model_len`` at 100.
    """
    # Loaded once and reused by every mocked engine below.
    tokenizer = get_tokenizer(MODEL_NAME)

    def make_engine(model_config: MockModelConfig) -> MagicMock:
        # A fresh mock per configuration, so ``generate.call_args`` from one
        # setup cannot be mistaken for the next.
        engine = MagicMock(spec=AsyncLLM)
        engine.get_tokenizer.return_value = tokenizer
        engine.errored = False
        engine.model_config = model_config
        engine.processor = MagicMock()
        engine.io_processor = MagicMock()
        return engine

    def make_request() -> ChatCompletionRequest:
        return ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "what is 1+1?"}],
        )

    async def sent_max_tokens(serving_chat, engine, req) -> int:
        # ``generate`` is a mock, so the request may fail after it is
        # submitted. Only the sampling params (2nd positional argument) that
        # reached ``generate`` are inspected.
        with suppress(Exception):
            await serving_chat.create_chat_completion(req)
        return engine.generate.call_args.args[1].max_tokens

    # No server-side default: bounded only by the remaining context window.
    mock_engine = make_engine(MockModelConfig())
    serving_chat = _build_serving_chat(mock_engine)
    req = make_request()
    assert await sent_max_tokens(serving_chat, mock_engine, req) == 93

    req.max_tokens = 10
    assert await sent_max_tokens(serving_chat, mock_engine, req) == 10

    # Setting server's max_tokens in the generation_config.json
    # lower than context_window - prompt_tokens
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 10  # Setting server-side max_tokens limit
    }
    mock_engine = make_engine(mock_model_config)
    serving_chat = _build_serving_chat(mock_engine)

    # Test Case 1: No max_tokens in the request -> server default applies
    req = make_request()
    assert await sent_max_tokens(serving_chat, mock_engine, req) == 10

    # Test Case 2: Request's max_tokens above the server default -> capped
    req.max_tokens = 15
    assert await sent_max_tokens(serving_chat, mock_engine, req) == 10

    # Test Case 3: Request's max_tokens below the server default -> honored
    req.max_tokens = 5
    assert await sent_max_tokens(serving_chat, mock_engine, req) == 5

    # Setting server's max_tokens in the generation_config.json
    # higher than context_window - prompt_tokens
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 200  # Setting server-side max_tokens limit
    }
    mock_engine = make_engine(mock_model_config)
    serving_chat = _build_serving_chat(mock_engine)

    # Test Case 1: No max_tokens in the request -> remaining context window
    req = make_request()
    assert await sent_max_tokens(serving_chat, mock_engine, req) == 93

    # Test Case 2: Request's max_tokens exceeds the remaining context window
    # (though not the server default) -> capped to the context window
    req.max_tokens = 100
    assert await sent_max_tokens(serving_chat, mock_engine, req) == 93

    # Test Case 3: Request's max_tokens below both limits -> honored
    req.max_tokens = 5
    assert await sent_max_tokens(serving_chat, mock_engine, req) == 5

@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
    """Server-side sampling defaults from ``diff_sampling_param`` reach the
    engine unless the request overrides them.

    ``repetition_penalty`` is never set on the request, so it must keep the
    server default throughout. ``temperature`` must follow the request once
    set, including an explicit ``0.0``.
    """
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05,
    }

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    # Initialize the serving chat
    serving_chat = _build_serving_chat(mock_engine)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # ``generate`` is a mock, so the request may fail after submission; only
    # the sampling params it received are checked.
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].temperature == 0.5
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05

    # Test the param when user set it
    req.temperature = 0.1

    with suppress(Exception):
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].temperature == 0.1
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05

    # Test When temperature==0.0: falsy, but still an explicit user value
    req.temperature = 0.0

    with suppress(Exception):
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].temperature == 0.0
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05


@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt(model_type):
    """``cache_salt`` is absent from the engine prompt unless the request
    sets it, in which case the request's value is forwarded."""
    mock_model_config = MockModelConfig()
    # Assign a fresh MockHFConfig rather than mutating the existing one. If
    # ``hf_config`` is shared between MockModelConfig instances, mutating it
    # would leak this ``model_type`` into unrelated tests.
    mock_model_config.hf_config = MockHFConfig(model_type=model_type)

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    # Test cache_salt
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # By default, cache_salt in the engine prompt is not set
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    # 2nd positional argument of the patched _process_inputs is the prompt.
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "cache_salt" not in engine_prompt

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"