# test_serving_chat.py (20.2 KB) -- page header captured from the repository web viewer
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
from contextlib import suppress
5
from dataclasses import dataclass, field
6
from typing import Any
7
from unittest.mock import AsyncMock, MagicMock
8

9
import pytest
10
import pytest_asyncio
11
from openai import OpenAI
12

13
from vllm.config.multimodal import MultiModalConfig
14
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
15
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
16
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
17
from vllm.transformers_utils.tokenizer import get_tokenizer
18
from vllm.v1.engine.async_llm import AsyncLLM
19

20
21
22
23
24
25
26
27
from ...utils import RemoteOpenAIServer

# Model identifier used both to launch the remote server fixture and as the
# ``model`` field of every request sent to it by the gpt-oss tests below.
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def monkeypatch_module():
    """Yield a ``MonkeyPatch`` object that lives for the whole test module.

    The patcher is created once per module, handed to dependants, and every
    patch applied through it is reverted with ``undo()`` at teardown.
    """
    from _pytest.monkeypatch import MonkeyPatch

    patcher = MonkeyPatch()
    yield patcher
    # Teardown: roll back everything that was patched through this object.
    patcher.undo()


34
35
36
37
38
@pytest.fixture(
    scope="module",
    params=[True, False],
    ids=["with_tool_parser", "without_tool_parser"],
)
def with_tool_parser(request) -> bool:
    """Parametrized switch: run dependants once with and once without a tool parser.

    Returns:
        The current parameter value (``True`` first, then ``False``).
    """
    tool_parser_enabled = request.param
    return tool_parser_enabled


43
@pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool):
    """Build the command-line arguments passed to the remote server fixture.

    Args:
        with_tool_parser: When true, the tool-call parser flags are appended.

    Returns:
        A list of CLI argument strings.
    """
    # Flags that are always passed, regardless of the tool-parser setting.
    base_flags = [
        "--enforce-eager",
        "--max-model-len",
        "4096",
        "--reasoning-parser",
        "openai_gptoss",
        "--gpu-memory-utilization",
        "0.8",
    ]
    # Flags that are only passed for the "with_tool_parser" parametrization.
    tool_flags = (
        ["--tool-call-parser", "openai", "--enable-auto-tool-choice"]
        if with_tool_parser
        else []
    )
    return base_flags + tool_flags


@pytest.fixture(scope="module")
def gptoss_server(
    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
):
    """Start a ``RemoteOpenAIServer`` for the gpt-oss model and yield it.

    ``VLLM_ATTENTION_BACKEND`` is set to ``TRITON_ATTN`` inside a monkeypatch
    context that stays open for as long as the server is running.
    """
    with monkeypatch_module.context() as env_patch:
        env_patch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
        server_cm = RemoteOpenAIServer(GPT_OSS_MODEL_NAME, default_server_args)
        with server_cm as running_server:
            yield running_server


@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client obtained from the running ``gptoss_server``.

    The client is opened with ``async with`` so it is closed again when the
    requesting test finishes.
    """
    client_cm = gptoss_server.get_async_client()
    async with client_cm as client:
        yield client


@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Stream a weather question and check how the answer is delivered.

    With the tool parser enabled, the stream must carry a tool-call name and
    non-empty arguments. Without it, no tools are offered and the stream must
    carry plain content and no tool call.
    """
    weather_parameters = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "state": {"type": "string"},
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city", "state", "unit"],
    }
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": weather_parameters,
            },
        }
    ]
    messages = [
        {"role": "user", "content": "What is the weather in Dallas, TX?"},
    ]

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        # Tools are only offered when the server was started with a parser.
        tools=tools if with_tool_parser else None,
        stream=True,
    )

    # Accumulate the pieces delivered across the streamed chunks.
    tool_name = None
    argument_parts = []
    content_parts = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # Only the first tool call of each delta is inspected.
            function = delta.tool_calls[0].function
            if function and function.name:
                tool_name = function.name
            if function and function.arguments:
                argument_parts.append(function.arguments)
        text = getattr(delta, "content", None)
        if text:
            content_parts.append(text)

    streamed_arguments = "".join(argument_parts)
    streamed_content = "".join(content_parts)
    if not with_tool_parser:
        assert tool_name is None
        assert not streamed_arguments
        assert streamed_content
    else:
        assert tool_name is not None
        assert streamed_arguments
141
142
143


@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool):
    """Run a two-turn conversation in which the first turn must be a tool call.

    The first response must call ``get_current_weather`` with non-empty
    arguments and no content. Those arguments are fed back as an assistant
    message, and the second response must contain content or a tool call.
    """
    if not with_tool_parser:
        pytest.skip("skip non-tool for multi-turn tests")

    weather_parameters = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "state": {"type": "string"},
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city", "state", "unit"],
    }
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": weather_parameters,
            },
        }
    ]
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
    ]
    # Keyword arguments shared by both turns; ``messages`` is passed per call.
    shared_kwargs = {
        "model": GPT_OSS_MODEL_NAME,
        "tools": tools,
        "temperature": 0.0,
    }

    # --- Turn 1: expect a tool call and no plain content. ---
    first = await gptoss_client.chat.completions.create(
        messages=messages, **shared_kwargs
    )
    first_msg = first.choices[0].message
    assert first_msg.tool_calls
    first_call = first_msg.tool_calls[0]
    assert first_call.function is not None
    assert first_call.function.name == "get_current_weather"
    call_arguments = first_call.function.arguments
    assert call_arguments
    assert not first_msg.content

    # --- Turn 2: replay the arguments as assistant content and follow up. ---
    messages.extend(
        [
            {"role": "assistant", "content": call_arguments},
            {"role": "user", "content": "Now convert to celsius and return JSON only"},
        ]
    )
    second = await gptoss_client.chat.completions.create(
        messages=messages, **shared_kwargs
    )
    second_msg = second.choices[0].message
    has_content = bool(second_msg.content)
    has_tool_calls = bool(second_msg.tool_calls)
    assert has_content or has_tool_calls
203
204


205
@pytest.mark.asyncio
async def test_gpt_oss_tool_message_array_content(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Test that tool messages support both string and array content formats."""
    if not with_tool_parser:
        pytest.skip("skip non-tool for array content tests")

    def _assistant_tool_call(call_id, arguments):
        # Assistant turn that "called" get_weather with the given JSON arguments.
        return {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": call_id,
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": arguments,
                    },
                }
            ],
        }

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                    },
                    "required": ["city", "state"],
                },
            },
        }
    ]

    # Case 1: tool message whose content is a plain string.
    messages_string = [
        {"role": "user", "content": "What's the weather in Paris?"},
        _assistant_tool_call("call_123", '{"city": "Paris", "state": "TX"}'),
        {"role": "tool", "content": "The weather in Paris, TX is sunny, 22°C"},
    ]

    # Case 2: tool message whose content is a single-item array.
    messages_array = [
        {"role": "user", "content": "What's the weather in Dallas?"},
        _assistant_tool_call("call_456", '{"city": "Dallas", "state": "TX"}'),
        {
            "role": "tool",
            "content": [
                {"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"}
            ],
        },
    ]

    # Case 3: tool message whose content is an array with several text items.
    messages_multi_array = [
        {"role": "user", "content": "Search for information"},
        _assistant_tool_call("call_789", '{"city": "Austin", "state": "TX"}'),
        {
            "role": "tool",
            "content": [
                {"type": "text", "text": "Weather data: "},
                {"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"},
                {"type": "text", "text": " with 60% humidity"},
            ],
        },
    ]

    # Each conversation is sent in turn; a response with a message is enough.
    for conversation in (messages_string, messages_array, messages_multi_array):
        response = await gptoss_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=conversation,
            tools=tools,
            temperature=0.0,
        )
        assert response is not None
        assert response.choices[0].message is not None


331
# Constants shared by the mocked-engine tests below.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# One entry per name; each name doubles as its own model path. The full name
# comes first and the short name second.
BASE_MODEL_PATHS = [
    BaseModelPath(name=served_name, model_path=served_name)
    for served_name in (MODEL_NAME, MODEL_NAME_SHORT)
]
338
339


340
341
342
343
344
@dataclass
class MockHFConfig:
    """Minimal stand-in for a Hugging Face config; carries only ``model_type``."""

    # Tests overwrite this to select a model type (e.g. "gpt_oss").
    model_type: str = "any"


345
346
@dataclass
class MockModelConfig:
    """Lightweight model-config stand-in assigned to mocked engines in these tests.

    Un-annotated names are plain class attributes shared by all instances;
    annotated names are dataclass fields and can be passed to ``__init__``.
    """

    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    # Class-level default; every instance gets its own copy in __post_init__.
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def __post_init__(self):
        # ``hf_config`` is not a dataclass field, so without this every
        # instance would share the single class-level MockHFConfig and a test
        # setting ``config.hf_config.model_type`` would change it for all
        # other configs in the module. Give each instance a private one.
        self.hf_config = MockHFConfig()

    def get_diff_sampling_param(self):
        """Return the sampling-parameter overrides, or an empty dict if unset."""
        return self.diff_sampling_param or {}
367
368


369
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` around ``engine`` with input processing stubbed.

    ``_process_inputs`` is replaced by an ``AsyncMock`` whose side effect
    returns a dict copy of the engine prompt plus an empty dict, so tests can
    inspect the recorded awaits.

    Args:
        engine: The (usually mocked) engine client to serve from.

    Returns:
        The configured serving-chat object.
    """

    async def _passthrough_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        # Hand the prompt back as a plain dict, with an empty second element.
        return dict(engine_prompt), {}

    chat = OpenAIServingChat(
        engine,
        OpenAIServingModels(
            engine_client=engine,
            base_model_paths=BASE_MODEL_PATHS,
        ),
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )
    chat._process_inputs = AsyncMock(side_effect=_passthrough_process_inputs)
    return chat


398
399
@dataclass
class MockEngine:
    """Bare-bones engine substitute used when no ``AsyncLLM`` mock is needed.

    Every instance gets its own ``MockModelConfig`` and its own ``MagicMock``
    objects for ``processor`` and ``io_processor``.
    """

    # Fresh config per engine instance.
    model_config: MockModelConfig = field(default_factory=MockModelConfig)
    # Opaque collaborators; any attribute access on them succeeds.
    processor: MagicMock = field(default_factory=MagicMock)
    io_processor: MagicMock = field(default_factory=MagicMock)
403
404
405


async def _async_serving_chat_init():
    """Construct an ``OpenAIServingChat`` on top of a ``MockEngine`` and return it."""
    mock_engine = MockEngine()
    return OpenAIServingChat(
        mock_engine,
        OpenAIServingModels(mock_engine, BASE_MODEL_PATHS),
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )


def test_async_serving_chat_init():
    """The serving chat built inside an event loop keeps the given chat template."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE
423
424


425
426
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
    """The full model name is used whichever way the request names the model.

    The full-response generator is replaced by a stub that returns its fourth
    positional argument, which each case expects to equal ``MODEL_NAME``.
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.processor = MagicMock()
    engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(engine)
    messages = [{"role": "user", "content": "what is 1+1?"}]

    async def _echo_fourth_argument(*args):
        return args[3]

    serving_chat.chat_completion_full_generator = _echo_fourth_argument

    # Model selection per case, in order: short name requested, empty string
    # specified, and no model specified at all.
    model_selections = (
        {"model": MODEL_NAME_SHORT},
        {"model": ""},
        {},
    )
    for selection in model_selections:
        req = ChatCompletionRequest(messages=messages, **selection)
        assert await serving_chat.create_chat_completion(req) == MODEL_NAME


455
456
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` passed to ``generate`` under three server settings.

    Covers no server-side limit, a limit of 10, and a limit of 200. In each
    scenario the request's own ``max_tokens`` is left unset first and then
    set to specific values.
    """

    def _make_engine(model_config):
        # AsyncLLM-shaped mock wired with a real tokenizer and the given config.
        engine = MagicMock(spec=AsyncLLM)
        engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
        engine.errored = False
        engine.model_config = model_config
        engine.processor = MagicMock()
        engine.io_processor = MagicMock()
        return engine

    def _make_request():
        return ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "what is 1+1?"}],
        )

    async def _max_tokens_sent(chat, engine, request):
        # Errors from the mocked pipeline are ignored on purpose; only the
        # arguments recorded on ``generate`` matter.
        with suppress(Exception):
            await chat.create_chat_completion(request)
        return engine.generate.call_args.args[1].max_tokens

    async def _check(chat, engine, cases):
        # ``cases`` holds (requested max_tokens or None, expected value) pairs.
        # ``None`` leaves the request's max_tokens untouched.
        request = _make_request()
        for requested, expected in cases:
            if requested is not None:
                request.max_tokens = requested
            assert await _max_tokens_sent(chat, engine, request) == expected

    # Scenario A: no server-side max_tokens configured.
    mock_engine = _make_engine(MockModelConfig())
    serving_chat = _build_serving_chat(mock_engine)
    await _check(serving_chat, mock_engine, [(None, 93), (10, 10)])

    # Scenario B: server's max_tokens in the generation_config.json is
    # lower than context_window - prompt_tokens.
    low_limit_config = MockModelConfig()
    low_limit_config.diff_sampling_param = {
        "max_tokens": 10  # Setting server-side max_tokens limit
    }
    mock_engine = _make_engine(low_limit_config)
    serving_chat = _build_serving_chat(mock_engine)
    await _check(
        serving_chat,
        mock_engine,
        [
            (None, 10),  # no max_tokens specified in request
            (15, 10),  # request asks for more than the server accepts
            (5, 5),  # request asks for less than the server accepts
        ],
    )

    # Scenario C: server's max_tokens in the generation_config.json is
    # higher than context_window - prompt_tokens.
    high_limit_config = MockModelConfig()
    high_limit_config.diff_sampling_param = {
        "max_tokens": 200  # Setting server-side max_tokens limit
    }
    mock_engine = _make_engine(high_limit_config)
    serving_chat = _build_serving_chat(mock_engine)
    await _check(
        serving_chat,
        mock_engine,
        [
            (None, 93),  # no max_tokens specified, defaults to context_window
            (100, 93),  # request asks for more than the server accepts
            (5, 5),  # request asks for less than the server accepts
        ],
    )

572

573
574
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
    """Server-side sampling defaults apply unless the request overrides them.

    The config supplies ``temperature`` 0.5 and ``repetition_penalty`` 1.05.
    The request then overrides the temperature with 0.1 and with 0.0, while
    the repetition penalty must stay at the configured value throughout.
    """
    model_config = MockModelConfig()
    model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05,
    }

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = model_config
    mock_engine.processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    # Initialize the serving chat
    serving_chat = _build_serving_chat(mock_engine)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # (temperature set on the request or None to leave it unset, expected value)
    temperature_cases = [
        (None, 0.5),  # nothing set by the user: config default applies
        (0.1, 0.1),  # user-provided value wins
        (0.0, 0.0),  # temperature == 0.0 must be honoured as well
    ]
    for requested_temperature, expected_temperature in temperature_cases:
        if requested_temperature is not None:
            req.temperature = requested_temperature

        # Errors from the mocked pipeline are ignored; only the arguments
        # recorded on ``generate`` are inspected.
        with suppress(Exception):
            await serving_chat.create_chat_completion(req)

        assert mock_engine.generate.call_args.args[1].temperature == expected_temperature
        assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
619
620


621
@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt(model_type):
    """``cache_salt`` reaches the engine prompt only when the request sets it.

    Runs once per ``model_type``. The engine prompt is read back from the
    awaits recorded on the mocked ``_process_inputs``.
    """
    mock_model_config = MockModelConfig()
    # Assign a fresh hf_config instead of mutating ``hf_config.model_type`` in
    # place: the default hf_config is a class attribute shared by every
    # MockModelConfig, so an in-place change would leak this test's model type
    # into the other tests in the module.
    mock_model_config.hf_config = MockHFConfig(model_type=model_type)

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    # Test cache_salt
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # By default, cache_salt in the engine prompt is not set
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    # Second positional argument of the first recorded await is the prompt.
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "cache_salt" not in engine_prompt

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"