"vllm/vscode:/vscode.git/clone" did not exist on "7042cc96b0a8a154ea165c652d4f63e5be9c291e"
test_serving_chat.py 23.7 KB
Newer Older
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from contextlib import suppress
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio
from openai import AsyncOpenAI, OpenAI

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

from ...utils import RemoteOpenAIServer

# Model name handed to RemoteOpenAIServer and sent as `model=` in the
# GPT-OSS requests issued by the tests in this module.
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def monkeypatch_module():
    """Yield a module-scoped ``MonkeyPatch`` and undo its patches afterwards.

    Uses the public ``pytest.MonkeyPatch`` API (already referenced by the
    ``gptoss_server`` fixture's annotation) instead of importing from the
    private ``_pytest`` package. The context manager guarantees ``undo()``
    runs even if the generator is closed by an exception.
    """
    with pytest.MonkeyPatch.context() as mpatch:
        yield mpatch


34
35
36
37
38
@pytest.fixture(
    scope="module",
    params=[True, False],
    ids=["with_tool_parser", "without_tool_parser"],
)
def with_tool_parser(request) -> bool:
    """Parametrize dependents over tool-parser enabled (True) / disabled (False)."""
    return request.param


43
@pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool) -> list[str]:
    """Build the command-line arguments for the GPT-OSS test server.

    Args:
        with_tool_parser: When true, the tool-call parser and automatic
            tool choice flags are appended to the base arguments.

    Returns:
        The flat list of CLI tokens passed to ``RemoteOpenAIServer``.
    """
    # NOTE(review): an earlier comment here mentioned "half precision", but
    # no dtype flag is passed; the list below is what is actually used.
    args = [
        "--enforce-eager",
        "--max-model-len",
        "4096",
        "--reasoning-parser",
        "openai_gptoss",
        "--gpu-memory-utilization",
        "0.8",
    ]
    if with_tool_parser:
        args.extend(
            [
                "--tool-call-parser",
                "openai",
                "--enable-auto-tool-choice",
            ]
        )
    return args


@pytest.fixture(scope="module")
def gptoss_server(
    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
):
    """Start a ``RemoteOpenAIServer`` for the GPT-OSS model and yield it.

    The ``VLLM_ATTENTION_BACKEND`` environment variable is set to
    ``TRITON_ATTN`` inside a monkeypatch context that wraps the server's
    whole lifetime, so the variable is restored once the server context
    has exited.
    """
    with monkeypatch_module.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
        with RemoteOpenAIServer(
            GPT_OSS_MODEL_NAME, default_server_args
        ) as remote_server:
            yield remote_server


@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client obtained from the running GPT-OSS server.

    The client is used as an async context manager, so it is closed when
    the dependent test finishes.
    """
    client_context = gptoss_server.get_async_client()
    async with client_context as client:
        yield client


@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(
    gptoss_client: AsyncOpenAI, with_tool_parser: bool
):
    """Stream a weather question and check how tool calls surface.

    With the tool parser enabled, ``tools`` is sent and the streamed deltas
    must carry a function name and non-empty arguments. Without it, no
    tools are sent; no tool call may appear and plain content must be
    streamed instead.
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["city", "state", "unit"],
                },
            },
        }
    ]

    messages = [
        {"role": "user", "content": "What is the weather in Dallas, TX?"},
    ]

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools if with_tool_parser else None,
        stream=True,
    )

    name = None
    # Fragments are collected in lists and joined once after the stream ends.
    argument_parts: list[str] = []
    content_parts: list[str] = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # Only the first tool call carried by each delta is inspected.
            function = delta.tool_calls[0].function
            if function and function.name:
                name = function.name
            if function and function.arguments:
                argument_parts.append(function.arguments)
        if getattr(delta, "content", None):
            content_parts.append(delta.content)

    args_buf = "".join(argument_parts)
    content_buf = "".join(content_parts)

    if with_tool_parser:
        assert name is not None
        assert len(args_buf) > 0
    else:
        assert name is None
        assert len(args_buf) == 0
        assert len(content_buf) > 0
141
142
143


@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(
    gptoss_client: AsyncOpenAI, with_tool_parser: bool
):
    """Run a two-turn conversation that starts with a tool call.

    Turn one must produce a ``get_current_weather`` tool call with
    non-empty arguments and no content. The arguments are then fed back as
    an assistant message, a follow-up user message is added, and turn two
    must return either content or another tool call. Skipped when the
    tool parser is disabled.
    """
    if not with_tool_parser:
        pytest.skip("skip non-tool for multi-turn tests")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["city", "state", "unit"],
                },
            },
        }
    ]

    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
    ]

    # Turn 1: expect a pure tool call (no content).
    first = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        temperature=0.0,
    )
    first_msg = first.choices[0].message
    assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
    tc = first_msg.tool_calls[0]
    assert tc.function is not None and tc.function.name == "get_current_weather"
    args1 = tc.function.arguments
    assert args1 is not None and len(args1) > 0
    assert not first_msg.content

    # Turn 2: replay the tool-call arguments as assistant content, then
    # ask a follow-up question.
    messages.append({"role": "assistant", "content": args1})
    messages.append(
        {"role": "user", "content": "Now convert to celsius and return JSON only"}
    )

    second = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        temperature=0.0,
    )
    second_msg = second.choices[0].message
    assert (second_msg.content is not None and len(second_msg.content) > 0) or (
        second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0
    )
203
204


205
@pytest.mark.asyncio
async def test_gpt_oss_tool_message_array_content(
    gptoss_client: AsyncOpenAI, with_tool_parser: bool
):
    """Test that tool messages support both string and array content formats."""
    if not with_tool_parser:
        pytest.skip("skip non-tool for array content tests")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                    },
                    "required": ["city", "state"],
                },
            },
        }
    ]

    async def _assert_completion_returned(messages):
        """Send one non-streaming request and check a message comes back."""
        response = await gptoss_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=messages,
            tools=tools,
            temperature=0.0,
        )
        assert response is not None
        assert response.choices[0].message is not None

    # Test 1: Tool message with string content
    messages_string = [
        {"role": "user", "content": "What's the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Paris", "state": "TX"}',
                    },
                }
            ],
        },
        {"role": "tool", "content": "The weather in Paris, TX is sunny, 22°C"},
    ]
    await _assert_completion_returned(messages_string)

    # Test 2: Tool message with array content
    messages_array = [
        {"role": "user", "content": "What's the weather in Dallas?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_456",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Dallas", "state": "TX"}',
                    },
                }
            ],
        },
        {
            "role": "tool",
            "content": [
                {"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"}
            ],
        },
    ]
    await _assert_completion_returned(messages_array)

    # Test 3: Tool message with multiple array content items
    messages_multi_array = [
        {"role": "user", "content": "Search for information"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_789",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Austin", "state": "TX"}',
                    },
                }
            ],
        },
        {
            "role": "tool",
            "content": [
                {"type": "text", "text": "Weather data: "},
                {"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"},
                {"type": "text", "text": " with 60% humidity"},
            ],
        },
    ]
    await _assert_completion_returned(messages_multi_array)


331
# Model identifiers and chat template shared by the mock-engine tests below.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# The full name is listed first, followed by its short alias.
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
338
339


340
341
342
343
344
@dataclass
class MockHFConfig:
    """Stand-in HF config exposing only ``model_type`` (default ``"any"``)."""

    model_type: str = "any"


345
346
@dataclass
class MockModelConfig:
    """Minimal stand-in for the engine's model config used by these tests.

    Names without an annotation are plain class attributes shared by every
    instance; annotated names are dataclass fields and can differ per
    instance.
    """

    task = "generate"
    runner_type = "generate"
    trust_remote_code = False
    max_model_len = 100
    multimodal_config = MultiModalConfig()
    logits_processors: list[str] | None = None
    logits_processor_pattern = None
    diff_sampling_param: dict | None = None
    encoder_config = None
    generation_config: str = "auto"
    # Created per instance so that a test assigning
    # ``config.hf_config.model_type`` cannot leak into other tests through a
    # shared class-level object. Declared after the pre-existing fields so
    # their positional order in ``__init__`` is unchanged.
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)

    def get_diff_sampling_param(self):
        """Return the configured sampling overrides, or ``{}`` if unset/empty."""
        return self.diff_sampling_param or {}
361
362


363
364
365
366
367
368
369
370
371
372
373
374
375
@dataclass
class MockRendererConfig:
    """Stand-in renderer config wrapping a ``MockModelConfig``.

    ``model_config`` is the first dataclass field, so it can be passed
    positionally (as the tests in this module do).
    """

    model_config: MockModelConfig = field(default_factory=MockModelConfig)

    # The four names below carry no annotation, so they are class
    # attributes rather than dataclass fields / constructor parameters.
    tokenizer = MODEL_NAME
    tokenizer_mode = "auto"
    tokenizer_revision = None
    skip_tokenizer_init = False
    # Annotated, hence real dataclass fields with per-instance defaults.
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None


376
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` around *engine* for the mock tests.

    The instance's ``_process_inputs`` is replaced by an ``AsyncMock`` whose
    side effect returns a dict copy of the engine prompt plus an empty
    dict. Tests can therefore inspect the prompts it was awaited with via
    ``serving_chat._process_inputs.await_args_list``.

    Args:
        engine: The (mock) engine passed to both the models registry and
            the serving-chat object.

    Returns:
        The configured ``OpenAIServingChat`` instance.
    """
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_chat = OpenAIServingChat(
        engine,
        models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        # Echo the prompt back (as a plain dict) with empty extra data.
        return dict(engine_prompt), {}

    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return serving_chat


405
406
@dataclass
class MockEngine:
    """Lightweight engine stand-in built entirely from mock components.

    Every field has a default factory, so ``MockEngine()`` needs no
    arguments and each instance gets its own config and mock objects.
    """

    model_config: MockModelConfig = field(default_factory=MockModelConfig)
    renderer_config: MockRendererConfig = field(default_factory=MockRendererConfig)
    input_processor: MagicMock = field(default_factory=MagicMock)
    io_processor: MagicMock = field(default_factory=MagicMock)
411
412
413


async def _async_serving_chat_init():
    """Construct an ``OpenAIServingChat`` on a ``MockEngine`` and return it.

    Declared ``async`` so the caller can run the construction inside an
    event loop via ``asyncio.run``.
    """
    engine = MockEngine()

    models = OpenAIServingModels(engine, BASE_MODEL_PATHS)
    # Named for what it is: a chat-serving object, not a completion one.
    serving_chat = OpenAIServingChat(
        engine,
        models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )
    return serving_chat


def test_async_serving_chat_init():
    """Check the chat template passed at construction is stored unchanged."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE
431
432


433
434
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
    """Check the full model name is reported for every way of naming it.

    ``chat_completion_full_generator`` is replaced by a stub that returns
    its fourth positional argument, which the assertions below expect to be
    the resolved model name.
    """
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.renderer_config = MockRendererConfig(mock_engine.model_config)
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)
    messages = [{"role": "user", "content": "what is 1+1?"}]

    async def return_model_name(*args):
        return args[3]

    serving_chat.chat_completion_full_generator = return_model_name

    # Test that full name is returned when short name is requested
    req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
    assert await serving_chat.create_chat_completion(req) == MODEL_NAME

    # Test that full name is returned when empty string is specified
    req = ChatCompletionRequest(model="", messages=messages)
    assert await serving_chat.create_chat_completion(req) == MODEL_NAME

    # Test that full name is returned when no model is specified
    req = ChatCompletionRequest(messages=messages)
    assert await serving_chat.create_chat_completion(req) == MODEL_NAME


464
465
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` handed to ``engine.generate``.

    Three server configurations are exercised: no server-side limit, a
    limit of 10, and a limit of 200. For each, requests without
    ``max_tokens`` and with explicit values are sent and the sampling
    params received by ``generate`` (its second positional argument) are
    inspected.
    """

    def _make_serving_chat(model_config):
        """Build a fresh mock engine around *model_config* plus its chat."""
        engine = MagicMock(spec=AsyncLLM)
        engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
        engine.errored = False
        engine.model_config = model_config
        engine.renderer_config = MockRendererConfig(model_config)
        engine.input_processor = MagicMock()
        engine.io_processor = MagicMock()
        return engine, _build_serving_chat(engine)

    def _new_request():
        return ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "what is 1+1?"}],
        )

    async def _max_tokens_sent(engine, serving_chat, req):
        """Send *req* and return the max_tokens that reached ``generate``."""
        # Drop earlier calls so a request that never reaches ``generate``
        # cannot be masked by a stale ``call_args`` from a previous one.
        engine.generate.reset_mock()
        # Errors raised after ``generate`` was invoked are irrelevant here.
        with suppress(Exception):
            await serving_chat.create_chat_completion(req)
        assert engine.generate.call_args is not None, (
            "engine.generate was not called for this request"
        )
        return engine.generate.call_args.args[1].max_tokens

    # --- No server-side max_tokens ---------------------------------------
    mock_engine, serving_chat = _make_serving_chat(MockModelConfig())
    req = _new_request()

    # NOTE(review): 93 is presumably max_model_len (100) minus the prompt
    # length -- confirm against the tokenizer output.
    assert await _max_tokens_sent(mock_engine, serving_chat, req) == 93

    req.max_tokens = 10
    assert await _max_tokens_sent(mock_engine, serving_chat, req) == 10

    # --- Server's max_tokens in the generation_config.json lower than
    # --- context_window - prompt_tokens ------------------------------------
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 10  # Setting server-side max_tokens limit
    }
    mock_engine, serving_chat = _make_serving_chat(mock_model_config)

    # Test Case 1: No max_tokens specified in request
    req = _new_request()
    assert await _max_tokens_sent(mock_engine, serving_chat, req) == 10

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 15
    assert await _max_tokens_sent(mock_engine, serving_chat, req) == 10

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5
    assert await _max_tokens_sent(mock_engine, serving_chat, req) == 5

    # --- Server's max_tokens in the generation_config.json higher than
    # --- context_window - prompt_tokens ------------------------------------
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 200  # Setting server-side max_tokens limit
    }
    mock_engine, serving_chat = _make_serving_chat(mock_model_config)

    # Test case 1: No max_tokens specified, defaults to context_window
    req = _new_request()
    assert await _max_tokens_sent(mock_engine, serving_chat, req) == 93

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 100
    assert await _max_tokens_sent(mock_engine, serving_chat, req) == 93

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5
    assert await _max_tokens_sent(mock_engine, serving_chat, req) == 5

584

585
586
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
    """Check server-side sampling defaults and per-request overrides.

    The model config supplies ``temperature`` and ``repetition_penalty``
    defaults. The request then overrides ``temperature`` (including with
    ``0.0``) and the sampling params received by ``engine.generate`` are
    inspected after each request.
    """
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05,
    }

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.renderer_config = MockRendererConfig(mock_model_config)
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    # Initialize the serving chat
    serving_chat = _build_serving_chat(mock_engine)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    async def _sampling_params_sent():
        """Send ``req`` and return the sampling params seen by ``generate``."""
        # Drop earlier calls so a request that never reaches ``generate``
        # cannot be masked by a stale ``call_args`` from a previous one.
        mock_engine.generate.reset_mock()
        # Errors raised after ``generate`` was invoked are irrelevant here.
        with suppress(Exception):
            await serving_chat.create_chat_completion(req)
        assert mock_engine.generate.call_args is not None, (
            "engine.generate was not called for this request"
        )
        return mock_engine.generate.call_args.args[1]

    # Server defaults apply when the request sets nothing
    params = await _sampling_params_sent()
    assert params.temperature == 0.5
    assert params.repetition_penalty == 1.05

    # Test the param when user set it
    req.temperature = 0.1
    params = await _sampling_params_sent()
    assert params.temperature == 0.1
    assert params.repetition_penalty == 1.05

    # Test When temperature==0.0
    req.temperature = 0.0
    params = await _sampling_params_sent()
    assert params.temperature == 0.0
    assert params.repetition_penalty == 1.05
632
633


634
@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt(model_type):
    """Check ``cache_salt`` reaches the engine prompt only when requested.

    Runs once per ``model_type``. The engine prompts are read from the
    mocked ``_process_inputs`` await history: the first request has no
    salt, the second carries ``"test_salt"``.
    """
    mock_model_config = MockModelConfig()
    # Assign a fresh HF config rather than mutating the existing one, so
    # the chosen model_type stays local to this test invocation.
    mock_model_config.hf_config = MockHFConfig(model_type=model_type)

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.renderer_config = MockRendererConfig(mock_model_config)
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    # Test cache_salt
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # By default, cache_salt in the engine prompt is not set
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "cache_salt" not in engine_prompt

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"
668
669
670
671
672
673
674
675
676
677


@pytest.mark.asyncio
async def test_serving_chat_data_parallel_rank_extraction():
    """Test that data_parallel_rank is properly extracted from header and
    passed to engine."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.renderer_config = MockRendererConfig(mock_engine.model_config)
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    # Mock the generate method to return an async generator
    async def mock_generate(*args, **kwargs):
        # Yield a fake RequestOutput
        from vllm.outputs import CompletionOutput, RequestOutput

        yield RequestOutput(
            request_id="test-request",
            prompt="test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="test response",
                    token_ids=[4, 5, 6],
                    cumulative_logprob=0.0,
                    logprobs=None,
                    finish_reason="stop",
                    stop_reason=None,
                )
            ],
            finished=True,
        )

    # NOTE(review): the assertions below only rely on the arguments the
    # mock records when it is called; whether the generator above is ever
    # iterated through an AsyncMock is worth confirming.
    mock_engine.generate = AsyncMock(side_effect=mock_generate)

    serving_chat = _build_serving_chat(mock_engine)

    # Test when data_parallel_rank is present in header
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # Mock request with X-data-parallel-rank header
    mock_raw_request = MagicMock()
    mock_raw_request.headers = {"X-data-parallel-rank": "2"}
    mock_raw_request.state = MagicMock()

    with suppress(Exception):
        await serving_chat.create_chat_completion(req, mock_raw_request)

    # Verify that data_parallel_rank was passed to engine.generate
    assert mock_engine.generate.call_args is not None
    assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
    assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] == 2

    # Forget the first call so the checks below can only be satisfied by
    # the second request.
    mock_engine.generate.reset_mock()

    # Test when data_parallel_rank is not present (defaults to None)
    req_no_dp = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 2+2?"}],
    )

    # Mock request with no header
    mock_raw_request_no_dp = MagicMock()
    mock_raw_request_no_dp.headers = {}
    mock_raw_request_no_dp.state = MagicMock()

    with suppress(Exception):
        await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)

    # Verify that data_parallel_rank defaults to None
    assert mock_engine.generate.call_args is not None
    assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
    assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None