test_serving_chat.py 20.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import asyncio
7
from contextlib import suppress
8
from dataclasses import dataclass, field
9
from typing import TYPE_CHECKING, Any
10
from unittest.mock import AsyncMock, MagicMock
11

12
import pytest
13
import pytest_asyncio
14

15
from vllm.config.multimodal import MultiModalConfig
16
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
17
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
18
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
19
from vllm.transformers_utils.tokenizer import get_tokenizer
20
from vllm.v1.engine.async_llm import AsyncLLM
21

22
23
24
25
26
27
28
29
30
31
32
from ...utils import RemoteOpenAIServer

if TYPE_CHECKING:
    from openai import OpenAI

# Model launched by the ``gptoss_server`` fixture and passed as ``model=`` in the
# GPT-OSS chat-completion requests in this file.
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def monkeypatch_module():
    """Yield a module-scoped ``MonkeyPatch`` and undo its changes at teardown.

    Module-scoped fixtures in this file (e.g. the server fixture) request this
    fixture so that they can patch the environment for the whole module.
    """
    # Use the public ``pytest.MonkeyPatch`` API instead of importing from the
    # private ``_pytest`` package; ``context()`` undoes every patch when the
    # ``with`` block exits, which happens when pytest tears the fixture down.
    with pytest.MonkeyPatch.context() as mpatch:
        yield mpatch


39
40
41
42
43
@pytest.fixture(
    scope="module",
    params=[True, False],
    ids=["with_tool_parser", "without_tool_parser"],
)
def with_tool_parser(request) -> bool:
    """Parametrized flag: ``True`` for the tool-parser variant, else ``False``."""
    tool_parser_enabled = request.param
    return tool_parser_enabled


48
@pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool):
    """Build the CLI argument list used to launch the test server.

    The tool-call parser flags are appended only when ``with_tool_parser`` is
    true; every other flag is shared by both variants.
    """
    # Flags common to both parametrizations.
    common_flags = [
        "--enforce-eager",
        "--max-model-len",
        "4096",
        "--reasoning-parser",
        "openai_gptoss",
        "--gpu-memory-utilization",
        "0.8",
    ]
    if not with_tool_parser:
        return common_flags

    # Extra flags that switch on automatic tool-call parsing.
    tool_flags = [
        "--tool-call-parser",
        "openai",
        "--enable-auto-tool-choice",
    ]
    return common_flags + tool_flags


@pytest.fixture(scope="module")
def gptoss_server(
    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
):
    """Start a ``RemoteOpenAIServer`` for the GPT-OSS model and yield it.

    ``VLLM_ATTENTION_BACKEND`` is set to ``TRITON_ATTN`` inside a monkeypatch
    context that stays open for as long as the server context is active.
    """
    with monkeypatch_module.context() as env_patch:
        # The variable must be in place before the server is constructed.
        env_patch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
        server_cm = RemoteOpenAIServer(GPT_OSS_MODEL_NAME, default_server_args)
        with server_cm as server:
            yield server


@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client obtained from the running ``gptoss_server``."""
    client_cm = gptoss_server.get_async_client()
    async with client_cm as client:
        yield client


@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Stream a weather question and inspect the accumulated deltas.

    With the tool parser, a tool-call name and non-empty arguments must be
    streamed. Without it (no tools are sent), no tool call may appear and the
    streamed content must be non-empty.
    """
    weather_parameters = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "state": {"type": "string"},
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city", "state", "unit"],
    }
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": weather_parameters,
            },
        }
    ]
    messages = [
        {"role": "user", "content": "What is the weather in Dallas, TX?"},
    ]

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        # Tools are only offered in the tool-parser variant.
        tools=tools if with_tool_parser else None,
        stream=True,
    )

    tool_name = None
    argument_pieces = []
    content_pieces = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # Only the first tool call of each delta is inspected.
            function = delta.tool_calls[0].function
            if function and function.name:
                tool_name = function.name
            if function and function.arguments:
                argument_pieces.append(function.arguments)
        text = getattr(delta, "content", None)
        if text:
            content_pieces.append(text)

    streamed_arguments = "".join(argument_pieces)
    streamed_content = "".join(content_pieces)

    if with_tool_parser:
        assert tool_name is not None
        assert len(streamed_arguments) > 0
    else:
        assert tool_name is None
        assert len(streamed_arguments) == 0
        assert len(streamed_content) > 0
146
147
148


@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool):
    """Run a two-turn conversation with tools enabled.

    The first turn must produce a ``get_current_weather`` tool call with
    non-empty arguments and no content. The second turn must produce either
    content or another tool call. Skipped in the variant without a tool parser.
    """
    if not with_tool_parser:
        pytest.skip("skip non-tool for multi-turn tests")

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "state": {"type": "string"},
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
    tools = [weather_tool]

    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
    ]

    # --- Turn 1: expect a tool call and no plain content. ---
    first_response = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        temperature=0.0,
    )
    first_message = first_response.choices[0].message
    assert first_message.tool_calls
    first_call = first_message.tool_calls[0]
    assert first_call.function is not None
    assert first_call.function.name == "get_current_weather"
    call_arguments = first_call.function.arguments
    assert call_arguments
    assert not first_message.content

    # --- Turn 2: feed the arguments back as assistant content and follow up. ---
    messages.extend(
        [
            {"role": "assistant", "content": call_arguments},
            {"role": "user", "content": "Now convert to celsius and return JSON only"},
        ]
    )
    second_response = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        temperature=0.0,
    )
    second_message = second_response.choices[0].message
    # Either a textual answer or a further tool call is acceptable.
    assert second_message.content or second_message.tool_calls
208
209


210
@pytest.mark.asyncio
async def test_gpt_oss_tool_message_array_content(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Test that tool messages support both string and array content formats."""
    if not with_tool_parser:
        pytest.skip("skip non-tool for array content tests")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                    },
                    "required": ["city", "state"],
                },
            },
        }
    ]

    def _conversation(question, call_id, arguments, tool_content):
        """Build user -> assistant tool call -> tool result messages."""
        return [
            {"role": "user", "content": question},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": call_id,
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": arguments,
                        },
                    }
                ],
            },
            {"role": "tool", "content": tool_content},
        ]

    conversations = [
        # Test 1: Tool message with string content
        _conversation(
            "What's the weather in Paris?",
            "call_123",
            '{"city": "Paris", "state": "TX"}',
            "The weather in Paris, TX is sunny, 22°C",
        ),
        # Test 2: Tool message with array content
        _conversation(
            "What's the weather in Dallas?",
            "call_456",
            '{"city": "Dallas", "state": "TX"}',
            [{"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"}],
        ),
        # Test 3: Tool message with multiple array content items
        _conversation(
            "Search for information",
            "call_789",
            '{"city": "Austin", "state": "TX"}',
            [
                {"type": "text", "text": "Weather data: "},
                {"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"},
                {"type": "text", "text": " with 60% humidity"},
            ],
        ),
    ]

    # The requests are issued one after another, in the order listed above.
    for conversation_messages in conversations:
        response = await gptoss_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=conversation_messages,
            tools=tools,
            temperature=0.0,
        )

        assert response is not None
        assert response.choices[0].message is not None


336
# Model identifiers and chat template used by the mocked serving-chat tests.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"

# One entry per identifier, each using the identifier as both name and path;
# the full name comes first.
BASE_MODEL_PATHS = [
    BaseModelPath(name=identifier, model_path=identifier)
    for identifier in (MODEL_NAME, MODEL_NAME_SHORT)
]
343
344


345
346
347
348
349
@dataclass
class MockHFConfig:
    """Stand-in for a Hugging Face config exposing only ``model_type``."""

    # Defaults to "any"; tests may overwrite it with a specific model type.
    model_type: str = "any"


350
351
@dataclass
class MockModelConfig:
    """Lightweight model-config stand-in assigned to the mocked engines.

    Only the annotated attributes are dataclass fields (and ``__init__``
    parameters). The un-annotated ones are plain class attributes shared by
    every instance, except ``hf_config``, which ``__post_init__`` replaces
    with a per-instance object.
    """

    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    # Class-level default, kept so ``MockModelConfig.hf_config`` still resolves;
    # instances get their own copy in ``__post_init__``.
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def __post_init__(self) -> None:
        """Give each instance a private ``hf_config``.

        Tests set ``hf_config.model_type`` in place. Without a per-instance
        object that write would land on the single class-level
        ``MockHFConfig`` and leak into every other ``MockModelConfig``.
        """
        self.hf_config = MockHFConfig()

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or an empty dict when it is unset."""
        return self.diff_sampling_param or {}
372
373


374
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` for ``engine`` with input processing mocked.

    ``_process_inputs`` is replaced by an ``AsyncMock`` so tests can inspect
    the arguments it was awaited with.
    """

    async def _passthrough_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        # Return a dict copy of the prompt paired with an empty dict.
        prompt_copy = dict(engine_prompt)
        return prompt_copy, {}

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    chat = OpenAIServingChat(
        engine,
        serving_models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )
    chat._process_inputs = AsyncMock(side_effect=_passthrough_process_inputs)
    return chat


403
404
@dataclass
class MockEngine:
    """Bare engine stand-in: a mock model config plus two ``MagicMock`` members."""

    # Each instance gets a fresh config and fresh mocks via default factories.
    model_config: MockModelConfig = field(default_factory=MockModelConfig)
    processor: MagicMock = field(default_factory=MagicMock)
    io_processor: MagicMock = field(default_factory=MagicMock)
408
409
410


async def _async_serving_chat_init():
    """Construct and return an ``OpenAIServingChat`` backed by a ``MockEngine``."""
    mock_engine = MockEngine()
    serving_models = OpenAIServingModels(mock_engine, BASE_MODEL_PATHS)

    return OpenAIServingChat(
        mock_engine,
        serving_models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )


def test_async_serving_chat_init():
    """The constructed serving chat must keep the chat template it was given."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    stored_template = serving_chat.chat_template
    assert stored_template == CHAT_TEMPLATE
428
429


430
431
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
    """The full model name is used whatever model the request names.

    The full-response generator is replaced by a stub that returns its fourth
    positional argument, which these assertions expect to be the model name.
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.processor = MagicMock()
    engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(engine)

    async def _echo_fourth_argument(*args):
        return args[3]

    serving_chat.chat_completion_full_generator = _echo_fourth_argument

    messages = [{"role": "user", "content": "what is 1+1?"}]

    # Short name requested, empty string specified, and no model specified:
    # each must resolve to the full name.
    model_overrides = (
        {"model": MODEL_NAME_SHORT},
        {"model": ""},
        {},
    )
    for override in model_overrides:
        req = ChatCompletionRequest(messages=messages, **override)
        assert await serving_chat.create_chat_completion(req) == MODEL_NAME


460
461
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` passed to ``engine.generate`` in three setups.

    Each scenario builds a fresh engine and serving chat, then reuses one
    request while changing its ``max_tokens``. The value asserted is read from
    the second positional argument of the latest ``generate`` call.
    """
    # (server-side diff_sampling_param or None, [(request max_tokens, expected)])
    # A request value of ``None`` means the request's max_tokens is left unset.
    scenarios = [
        # No server-side limit configured.
        (None, [(None, 93), (10, 10)]),
        # Setting server's max_tokens in the generation_config.json
        # lower than context_window - prompt_tokens
        ({"max_tokens": 10}, [(None, 10), (15, 10), (5, 5)]),
        # Setting server's max_tokens in the generation_config.json
        # higher than context_window - prompt_tokens
        ({"max_tokens": 200}, [(None, 93), (100, 93), (5, 5)]),
    ]

    for server_sampling_params, expectations in scenarios:
        model_config = MockModelConfig()
        if server_sampling_params is not None:
            model_config.diff_sampling_param = server_sampling_params

        # Reinitialize the engine and the serving chat for every scenario.
        engine = MagicMock(spec=AsyncLLM)
        engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
        engine.errored = False
        engine.model_config = model_config
        engine.processor = MagicMock()
        engine.io_processor = MagicMock()
        serving_chat = _build_serving_chat(engine)

        req = ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "what is 1+1?"}],
        )

        for requested_max_tokens, expected_max_tokens in expectations:
            if requested_max_tokens is not None:
                req.max_tokens = requested_max_tokens

            # Errors raised by the mocked pipeline are deliberately ignored;
            # only the arguments handed to ``generate`` matter here.
            with suppress(Exception):
                await serving_chat.create_chat_completion(req)

            sampling_params = engine.generate.call_args.args[1]
            assert sampling_params.max_tokens == expected_max_tokens

577

578
579
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
    """Server-side sampling defaults apply unless the request overrides them.

    The config supplies ``temperature`` and ``repetition_penalty``. The request
    then overrides ``temperature`` with 0.1 and with 0.0, and the repetition
    penalty must stay at the configured value throughout.
    """
    model_config = MockModelConfig()
    model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05,
    }

    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = model_config
    engine.processor = MagicMock()
    engine.io_processor = MagicMock()

    # Initialize the serving chat
    serving_chat = _build_serving_chat(engine)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # (temperature set on the request or None to leave it unset, expected value)
    temperature_cases = (
        (None, 0.5),  # falls back to the configured default
        (0.1, 0.1),  # Test the param when user set it
        (0.0, 0.0),  # Test When temperature==0.0
    )
    for requested_temperature, expected_temperature in temperature_cases:
        if requested_temperature is not None:
            req.temperature = requested_temperature

        # Errors from the mocked pipeline are ignored on purpose.
        with suppress(Exception):
            await serving_chat.create_chat_completion(req)

        sampling_params = engine.generate.call_args.args[1]
        assert sampling_params.temperature == expected_temperature
        assert sampling_params.repetition_penalty == 1.05
624
625


626
@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt(model_type):
    """``cache_salt`` reaches the engine prompt only when the request sets it.

    The engine prompt is read from the second positional argument of each
    awaited ``_process_inputs`` call.
    """
    mock_model_config = MockModelConfig()
    # Assign a fresh ``MockHFConfig`` on this instance rather than writing to
    # ``mock_model_config.hf_config.model_type``: the default ``hf_config`` is
    # declared as a class-level attribute of ``MockModelConfig``, so mutating
    # it in place would change it for every other config as well.
    mock_model_config.hf_config = MockHFConfig(model_type=model_type)

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    # Test cache_salt
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # By default, cache_salt in the engine prompt is not set
    # (errors from the mocked pipeline are ignored on purpose).
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "cache_salt" not in engine_prompt

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"