test_serving_chat.py 23.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
zhuwenwen's avatar
zhuwenwen committed
3
import os
4

5
import asyncio
6
from contextlib import suppress
7
from dataclasses import dataclass, field
8
from typing import Any
9
from unittest.mock import AsyncMock, MagicMock
10

11
import pytest
12
import pytest_asyncio
13
from openai import OpenAI
14

15
from vllm.config.multimodal import MultiModalConfig
16
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
17
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
18
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
19
from vllm.tokenizers import get_tokenizer
20
from vllm.v1.engine.async_llm import AsyncLLM
21

22
from ...utils import RemoteOpenAIServer, models_path_prefix
23

24
# Model path handed to RemoteOpenAIServer and used as the `model` field of
# every request in the gpt-oss end-to-end tests below.
GPT_OSS_MODEL_NAME = os.path.join(models_path_prefix, "openai/gpt-oss-20b")
25
26
27
28
29


@pytest.fixture(scope="module")
def monkeypatch_module():
    """Provide a module-scoped ``MonkeyPatch``.

    The built-in ``monkeypatch`` fixture is function-scoped, so module-scoped
    fixtures in this file need their own instance.

    Yields:
        A ``pytest.MonkeyPatch`` whose changes are undone when the module's
        fixtures are torn down.
    """
    # Public API equivalent of instantiating `_pytest.monkeypatch.MonkeyPatch`
    # and calling `undo()` by hand; the context manager undoes on exit.
    with pytest.MonkeyPatch.context() as mpatch:
        yield mpatch


36
37
38
39
40
@pytest.fixture(
    scope="module",
    params=[True, False],
    ids=["with_tool_parser", "without_tool_parser"],
)
def with_tool_parser(request) -> bool:
    """Parametrize the module over running with and without a tool parser."""
    enabled: bool = request.param
    return enabled


45
@pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool):
    """Build the command-line flags for the gpt-oss test server.

    Args:
        with_tool_parser: When true, the tool-call parser flags are appended
            to the base flags.

    Returns:
        A new list of CLI argument strings.
    """
    base_flags = [
        "--enforce-eager",
        "--max-model-len",
        "4096",
        "--reasoning-parser",
        "openai_gptoss",
        "--gpu-memory-utilization",
        "0.8",
    ]
    tool_flags = [
        "--tool-call-parser",
        "openai",
        "--enable-auto-tool-choice",
    ]
    return (base_flags + tool_flags) if with_tool_parser else base_flags


@pytest.fixture(scope="module")
def gptoss_server(
    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
):
    """Start a gpt-oss server for the module and yield it.

    ``VLLM_ATTENTION_BACKEND`` is set to ``TRITON_ATTN`` before the server is
    created and stays set for as long as the server is running.
    """
    with monkeypatch_module.context() as env_patch:
        env_patch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
        server_cm = RemoteOpenAIServer(GPT_OSS_MODEL_NAME, default_server_args)
        with server_cm as server:
            yield server


@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client obtained from the running gpt-oss server."""
    async with gptoss_server.get_async_client() as client:
        yield client


@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Stream a weather question and check what kind of deltas come back.

    With the tool parser, a tool name and argument fragments must be
    streamed. Without it (no tools are sent either), only plain content is
    expected.
    """
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "state": {"type": "string"},
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=[
            {"role": "user", "content": "What is the weather in Dallas, TX?"},
        ],
        tools=[weather_tool] if with_tool_parser else None,
        stream=True,
    )

    tool_name = None
    arg_fragments: list[str] = []
    content_fragments: list[str] = []

    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # Only the first tool call of each delta is inspected.
            function = delta.tool_calls[0].function
            if function:
                if function.name:
                    tool_name = function.name
                if function.arguments:
                    arg_fragments.append(function.arguments)
        text = getattr(delta, "content", None)
        if text:
            content_fragments.append(text)

    streamed_args = "".join(arg_fragments)
    if not with_tool_parser:
        assert tool_name is None
        assert len(streamed_args) == 0
        assert len("".join(content_fragments)) > 0
    else:
        assert tool_name is not None
        assert len(streamed_args) > 0
143
144
145


@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool):
    """Run a two-turn conversation whose first turn must be a tool call.

    The first reply has to call ``get_current_weather`` with non-empty
    arguments and no content. After feeding those arguments back and asking a
    follow-up, the second reply must carry either content or a tool call.
    """
    if not with_tool_parser:
        pytest.skip("skip non-tool for multi-turn tests")

    weather_parameters = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "state": {"type": "string"},
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city", "state", "unit"],
    }
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": weather_parameters,
            },
        }
    ]
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
    ]

    async def next_reply():
        # Sends the conversation as it stands at call time.
        completion = await gptoss_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=conversation,
            tools=tools,
            temperature=0.0,
        )
        return completion.choices[0].message

    # Turn 1: expect a pure tool call.
    first_msg = await next_reply()
    assert first_msg.tool_calls is not None
    assert len(first_msg.tool_calls) > 0
    call = first_msg.tool_calls[0]
    assert call.function is not None
    assert call.function.name == "get_current_weather"
    first_args = call.function.arguments
    assert first_args is not None
    assert len(first_args) > 0
    assert not first_msg.content

    # Turn 2: replay the arguments as assistant content, then follow up.
    conversation.extend(
        [
            {"role": "assistant", "content": first_args},
            {"role": "user", "content": "Now convert to celsius and return JSON only"},
        ]
    )
    second_msg = await next_reply()
    has_content = second_msg.content is not None and len(second_msg.content) > 0
    has_tool_calls = (
        second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0
    )
    assert has_content or has_tool_calls
205
206


207
@pytest.mark.asyncio
async def test_gpt_oss_tool_message_array_content(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Test that tool messages support both string and array content formats."""
    if not with_tool_parser:
        pytest.skip("skip non-tool for array content tests")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                    },
                    "required": ["city", "state"],
                },
            },
        }
    ]

    def conversation(question, call_id, call_arguments, tool_content):
        # user question -> assistant tool call -> tool result
        return [
            {"role": "user", "content": question},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": call_id,
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": call_arguments,
                        },
                    }
                ],
            },
            {"role": "tool", "content": tool_content},
        ]

    cases = (
        # Test 1: Tool message with string content
        (
            "What's the weather in Paris?",
            "call_123",
            '{"city": "Paris", "state": "TX"}',
            "The weather in Paris, TX is sunny, 22°C",
        ),
        # Test 2: Tool message with array content
        (
            "What's the weather in Dallas?",
            "call_456",
            '{"city": "Dallas", "state": "TX"}',
            [{"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"}],
        ),
        # Test 3: Tool message with multiple array content items
        (
            "Search for information",
            "call_789",
            '{"city": "Austin", "state": "TX"}',
            [
                {"type": "text", "text": "Weather data: "},
                {"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"},
                {"type": "text", "text": " with 60% humidity"},
            ],
        ),
    )

    for case in cases:
        response = await gptoss_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=conversation(*case),
            tools=tools,
            temperature=0.0,
        )

        assert response is not None
        assert response.choices[0].message is not None
331

332

333
# Model identifiers and chat template shared by the mocked-engine tests below.
MODEL_NAME = os.path.join(models_path_prefix, "openai-community/gpt2")
MODEL_NAME_SHORT = os.path.join(models_path_prefix, "gpt2")
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# Both the full and the short name are registered, each as its own path.
BASE_MODEL_PATHS = [
    BaseModelPath(name=model, model_path=model)
    for model in (MODEL_NAME, MODEL_NAME_SHORT)
]
340
341


342
343
344
345
346
@dataclass
class MockHFConfig:
    """Stand-in for a Hugging Face config that exposes only ``model_type``."""

    model_type: str = "any"


347
348
@dataclass
class MockModelConfig:
    """Minimal model-config stand-in for the mocked-engine tests.

    Only the annotated attributes are dataclass fields (and thus accepted by
    ``__init__``). The un-annotated ones are plain class attributes shared by
    every instance.
    """

    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    # Built per instance: tests mutate `hf_config.model_type`, and a single
    # class-level MockHFConfig would leak that mutation into every other
    # MockModelConfig (and therefore into other tests).
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)
    logits_processors: list[str] | None = None
    logits_processor_pattern = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or an empty dict when it is unset."""
        return self.diff_sampling_param or {}
370
371


372
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` on *engine* with input processing stubbed.

    ``_process_inputs`` is replaced by an ``AsyncMock`` so tests can inspect
    the engine prompt each request was processed with.
    """

    async def _stub_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        # Hand back a copy of the prompt plus an empty second element.
        return dict(engine_prompt), {}

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    chat = OpenAIServingChat(
        engine,
        serving_models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )
    chat._process_inputs = AsyncMock(side_effect=_stub_process_inputs)
    return chat


401
402
@dataclass
class MockEngine:
    """Bare engine stand-in: a mock model config plus two ``MagicMock`` processors."""

    model_config: MockModelConfig = field(default_factory=MockModelConfig)
    input_processor: MagicMock = field(default_factory=MagicMock)
    io_processor: MagicMock = field(default_factory=MagicMock)
406
407
408


async def _async_serving_chat_init():
    """Construct and return an ``OpenAIServingChat`` backed by a ``MockEngine``."""
    mock_engine = MockEngine()
    serving_models = OpenAIServingModels(mock_engine, BASE_MODEL_PATHS)
    return OpenAIServingChat(
        mock_engine,
        serving_models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )


def test_async_serving_chat_init():
    """The serving chat built inside an event loop keeps the chat template."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE
426
427


428
429
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
    """The full model name is reported whatever name the request used."""
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(engine)

    async def return_model_name(*args):
        # Echo positional argument 3, which the assertions compare with
        # MODEL_NAME.
        return args[3]

    serving_chat.chat_completion_full_generator = return_model_name

    messages = [{"role": "user", "content": "what is 1+1?"}]
    model_selectors = (
        # Test that full name is returned when short name is requested
        {"model": MODEL_NAME_SHORT},
        # Test that full name is returned when empty string is specified
        {"model": ""},
        # Test that full name is returned when no model is specified
        {},
    )
    for selector in model_selectors:
        req = ChatCompletionRequest(messages=messages, **selector)
        assert await serving_chat.create_chat_completion(req) == MODEL_NAME


458
459
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` handed to ``engine.generate``.

    Three server setups are exercised: no server-side limit, a server-side
    limit of 10, and a server-side limit of 200. Each is tried with the
    request's own ``max_tokens`` unset, above and below the effective limit.
    """

    def start(model_config):
        # Fresh mocked engine + serving chat + request for one server setup.
        engine = MagicMock(spec=AsyncLLM)
        engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
        engine.errored = False
        engine.model_config = model_config
        engine.input_processor = MagicMock()
        engine.io_processor = MagicMock()
        chat = _build_serving_chat(engine)
        request = ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "what is 1+1?"}],
        )
        return engine, chat, request

    async def generated_max_tokens(engine, chat, request):
        # The mocked engine cannot produce real outputs, so errors raised
        # after `generate` was invoked are deliberately ignored.
        # NOTE(review): if the call fails *before* reaching `generate`,
        # `call_args` still holds the previous call's arguments — confirm
        # this cannot mask a regression.
        with suppress(Exception):
            await chat.create_chat_completion(request)
        return engine.generate.call_args.args[1].max_tokens

    # --- No server-side max_tokens -----------------------------------------
    engine, chat, req = start(MockModelConfig())
    assert await generated_max_tokens(engine, chat, req) == 93

    req.max_tokens = 10
    assert await generated_max_tokens(engine, chat, req) == 10

    # --- Server's max_tokens in the generation_config.json lower than
    # --- context_window - prompt_tokens ------------------------------------
    engine, chat, req = start(
        MockModelConfig(diff_sampling_param={"max_tokens": 10})
    )

    # Test Case 1: No max_tokens specified in request
    assert await generated_max_tokens(engine, chat, req) == 10

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 15
    assert await generated_max_tokens(engine, chat, req) == 10

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5
    assert await generated_max_tokens(engine, chat, req) == 5

    # --- Server's max_tokens in the generation_config.json higher than
    # --- context_window - prompt_tokens ------------------------------------
    engine, chat, req = start(
        MockModelConfig(diff_sampling_param={"max_tokens": 200})
    )

    # Test case 1: No max_tokens specified, defaults to context_window
    assert await generated_max_tokens(engine, chat, req) == 93

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 100
    assert await generated_max_tokens(engine, chat, req) == 93

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5
    assert await generated_max_tokens(engine, chat, req) == 5

575

576
577
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
    """Server-side sampling defaults apply unless the request overrides them."""
    model_config = MockModelConfig(
        diff_sampling_param={
            "temperature": 0.5,
            "repetition_penalty": 1.05,
        }
    )

    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = model_config
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(engine)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    async def generated_params():
        # Errors from the mocked engine are ignored; only the arguments
        # passed to `generate` matter here.
        with suppress(Exception):
            await serving_chat.create_chat_completion(req)
        return engine.generate.call_args.args[1]

    # Nothing set on the request: both server defaults are used.
    params = await generated_params()
    assert params.temperature == 0.5
    assert params.repetition_penalty == 1.05

    # Test the param when user set it
    req.temperature = 0.1
    params = await generated_params()
    assert params.temperature == 0.1
    assert params.repetition_penalty == 1.05

    # Test When temperature==0.0
    req.temperature = 0.0
    params = await generated_params()
    assert params.temperature == 0.0
    assert params.repetition_penalty == 1.05
622
623


624
@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt(model_type):
    """``cache_salt`` reaches the engine prompt only when the request sets it.

    Args:
        model_type: Value placed in the mocked HF config's ``model_type``.
    """
    mock_model_config = MockModelConfig()
    # Give this instance its own MockHFConfig rather than mutating the one it
    # already references, so the model_type chosen here cannot leak into
    # other MockModelConfig instances.
    mock_model_config.hf_config = MockHFConfig(model_type=model_type)

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    # Test cache_salt
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # By default, cache_salt in the engine prompt is not set
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    # Positional argument 1 of the stubbed `_process_inputs` is the engine
    # prompt.
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "cache_salt" not in engine_prompt

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"
657
658
659
660
661
662
663
664
665
666


@pytest.mark.asyncio
async def test_serving_chat_data_parallel_rank_extraction():
    """Test that data_parallel_rank is properly extracted from header and
    passed to engine."""
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    # Mock the generate method to return an async generator
    async def mock_generate(*args, **kwargs):
        from vllm.outputs import CompletionOutput, RequestOutput

        # A single, already-finished fake RequestOutput.
        completion = CompletionOutput(
            index=0,
            text="test response",
            token_ids=[4, 5, 6],
            cumulative_logprob=0.0,
            logprobs=None,
            finish_reason="stop",
            stop_reason=None,
        )
        yield RequestOutput(
            request_id="test-request",
            prompt="test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[completion],
            finished=True,
        )

    engine.generate = AsyncMock(side_effect=mock_generate)

    serving_chat = _build_serving_chat(engine)

    async def forwarded_rank(question, headers):
        # Send one chat request carrying `headers` and return the
        # data_parallel_rank keyword that reached `engine.generate`.
        request = ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": question}],
        )
        raw_request = MagicMock()
        raw_request.headers = headers
        raw_request.state = MagicMock()

        with suppress(Exception):
            await serving_chat.create_chat_completion(request, raw_request)

        generate_kwargs = engine.generate.call_args.kwargs
        assert "data_parallel_rank" in generate_kwargs
        return generate_kwargs["data_parallel_rank"]

    # Test when data_parallel_rank is present in header
    rank = await forwarded_rank("what is 1+1?", {"X-data-parallel-rank": "2"})
    assert rank == 2

    # Test when data_parallel_rank is not present (defaults to None)
    rank = await forwarded_rank("what is 2+2?", {})
    assert rank is None