test_serving_chat.py 22.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
from contextlib import suppress
5
from dataclasses import dataclass, field
6
from typing import Any
7
from unittest.mock import AsyncMock, MagicMock
8

9
import pytest
10
import pytest_asyncio
11
from openai import OpenAI
12

13
from vllm.config.multimodal import MultiModalConfig
14
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
15
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
16
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
17
from vllm.tokenizers import get_tokenizer
18
from vllm.v1.engine.async_llm import AsyncLLM
19

20
21
22
23
24
25
26
27
from ...utils import RemoteOpenAIServer

# Model id passed to RemoteOpenAIServer and used as `model=` in the GPT-OSS
# end-to-end tests in this file.
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def monkeypatch_module():
    """Yield a ``MonkeyPatch`` shared by the whole module.

    All patches applied through it are reverted once the module's tests
    have finished.
    """
    from _pytest.monkeypatch import MonkeyPatch

    patcher = MonkeyPatch()
    yield patcher
    # Teardown: restore everything that was patched through this instance.
    patcher.undo()


34
35
36
37
38
@pytest.fixture(
    scope="module",
    params=[True, False],
    ids=["with_tool_parser", "without_tool_parser"],
)
def with_tool_parser(request) -> bool:
    """Parametrize dependents over the tool-parser-enabled/disabled cases."""
    enabled: bool = request.param
    return enabled


43
@pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool):
    """Build the CLI argument list for the GPT-OSS test server.

    Args:
        with_tool_parser: When true, the tool-call parser flags are appended
            to the base arguments.

    Returns:
        A list of command-line argument strings.
    """
    base_args = [
        # Eager execution, a 4096-token context, the GPT-OSS reasoning parser
        # and a 0.8 GPU memory utilization setting.
        "--enforce-eager",
        "--max-model-len",
        "4096",
        "--reasoning-parser",
        "openai_gptoss",
        "--gpu-memory-utilization",
        "0.8",
    ]
    tool_args = [
        "--tool-call-parser",
        "openai",
        "--enable-auto-tool-choice",
    ]
    return base_args + tool_args if with_tool_parser else base_args


@pytest.fixture(scope="module")
def gptoss_server(
    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
):
    """Run a ``RemoteOpenAIServer`` for the GPT-OSS model for the module.

    ``VLLM_ATTENTION_BACKEND`` is set to ``TRITON_ATTN`` before the server
    object is created, and the patch context is left once the server has
    been shut down.
    """
    with monkeypatch_module.context() as patch_ctx:
        patch_ctx.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
        server = RemoteOpenAIServer(GPT_OSS_MODEL_NAME, default_server_args)
        with server as remote_server:
            yield remote_server


@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client for ``gptoss_server``, closed after the test."""
    client_cm = gptoss_server.get_async_client()
    async with client_cm as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Stream a weather question and check what kind of deltas come back.

    With the tool parser enabled (and tools supplied) the stream must carry
    a tool name and argument fragments. Without it, no tool call may appear
    and plain content must be produced.
    """
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "state": {"type": "string"},
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=[
            {"role": "user", "content": "What is the weather in Dallas, TX?"},
        ],
        tools=[weather_tool] if with_tool_parser else None,
        stream=True,
    )

    tool_name = None
    arg_parts = []
    text_parts = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # Only the first tool call of each delta is inspected.
            fn = delta.tool_calls[0].function
            if fn and fn.name:
                tool_name = fn.name
            if fn and fn.arguments:
                arg_parts.append(fn.arguments)
        if getattr(delta, "content", None):
            text_parts.append(delta.content)

    streamed_args = "".join(arg_parts)
    streamed_text = "".join(text_parts)

    if with_tool_parser:
        assert tool_name is not None
        assert len(streamed_args) > 0
    else:
        assert tool_name is None
        assert len(streamed_args) == 0
        assert len(streamed_text) > 0
141
142
143


@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool):
    """Run a two-turn chat whose first turn must yield a weather tool call.

    The first reply's tool-call arguments are fed back as assistant content
    together with a follow-up user message; the second reply must contain
    either text or a tool call. Skipped when the tool parser is disabled.
    """
    if not with_tool_parser:
        pytest.skip("skip non-tool for multi-turn tests")

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "state": {"type": "string"},
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
    tools = [weather_tool]
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
    ]

    async def _ask():
        # Sends the conversation as it stands at call time.
        completion = await gptoss_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=messages,
            tools=tools,
            temperature=0.0,
        )
        return completion.choices[0].message

    first_msg = await _ask()
    assert first_msg.tool_calls is not None
    assert len(first_msg.tool_calls) > 0
    call = first_msg.tool_calls[0]
    assert call.function is not None
    assert call.function.name == "get_current_weather"
    call_args = call.function.arguments
    assert call_args is not None
    assert len(call_args) > 0
    assert not first_msg.content

    messages.extend(
        [
            {"role": "assistant", "content": call_args},
            {"role": "user", "content": "Now convert to celsius and return JSON only"},
        ]
    )

    second_msg = await _ask()
    has_text = second_msg.content is not None and len(second_msg.content) > 0
    has_calls = second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0
    assert has_text or has_calls
203
204


205
@pytest.mark.asyncio
async def test_gpt_oss_tool_message_array_content(
    gptoss_client: OpenAI, with_tool_parser: bool
):
    """Test that tool messages support both string and array content formats."""
    if not with_tool_parser:
        pytest.skip("skip non-tool for array content tests")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                    },
                    "required": ["city", "state"],
                },
            },
        }
    ]

    # Each case: (user prompt, tool-call id, tool-call arguments, tool content).
    cases = [
        # Test 1: Tool message with string content
        (
            "What's the weather in Paris?",
            "call_123",
            '{"city": "Paris", "state": "TX"}',
            "The weather in Paris, TX is sunny, 22°C",
        ),
        # Test 2: Tool message with array content
        (
            "What's the weather in Dallas?",
            "call_456",
            '{"city": "Dallas", "state": "TX"}',
            [{"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"}],
        ),
        # Test 3: Tool message with multiple array content items
        (
            "Search for information",
            "call_789",
            '{"city": "Austin", "state": "TX"}',
            [
                {"type": "text", "text": "Weather data: "},
                {"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"},
                {"type": "text", "text": " with 60% humidity"},
            ],
        ),
    ]

    for user_text, call_id, call_args, tool_content in cases:
        conversation = [
            {"role": "user", "content": user_text},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": call_id,
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": call_args,
                        },
                    }
                ],
            },
            {"role": "tool", "content": tool_content},
        ]

        response = await gptoss_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=conversation,
            tools=tools,
            temperature=0.0,
        )

        # The request only has to be accepted and produce a message.
        assert response is not None
        assert response.choices[0].message is not None


331
# Model identifiers and chat template shared by the mocked-engine tests.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# One entry per identifier; each one's name and model_path are the same string.
BASE_MODEL_PATHS = [
    BaseModelPath(name=model_id, model_path=model_id)
    for model_id in (MODEL_NAME, MODEL_NAME_SHORT)
]
338
339


340
341
342
343
344
@dataclass
class MockHFConfig:
    """Stand-in for a Hugging Face config exposing only ``model_type``."""

    # Defaults to a placeholder value; tests override it where it matters.
    model_type: str = "any"


345
346
@dataclass
class MockModelConfig:
    """Minimal model-config stand-in used by the mocked-engine tests.

    Attributes without a type annotation are plain class attributes (shared
    by all instances and not accepted by the constructor); annotated ones are
    dataclass fields.
    """

    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    # Class-level default; every instance gets its own copy in __post_init__.
    hf_config = MockHFConfig()
    logits_processors: list[str] | None = None
    logits_processor_pattern = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def __post_init__(self):
        # Tests set ``config.hf_config.model_type``; without a per-instance
        # object that write would land on the shared class attribute and
        # leak into every other MockModelConfig.
        self.hf_config = MockHFConfig()

    def get_diff_sampling_param(self):
        """Return the configured sampling overrides, or an empty dict if unset."""
        return self.diff_sampling_param or {}
368
369


370
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` for ``engine`` with input processing stubbed.

    ``_process_inputs`` is replaced by an ``AsyncMock`` whose side effect
    returns the engine prompt copied into a plain dict together with an empty
    dict, so tests can inspect the arguments it was awaited with.

    Args:
        engine: Engine (or mock) handed to both the models registry and the
            chat handler.

    Returns:
        The configured chat handler.
    """
    serving_chat = OpenAIServingChat(
        engine,
        OpenAIServingModels(
            engine_client=engine,
            base_model_paths=BASE_MODEL_PATHS,
        ),
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )

    async def _passthrough_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        prompt_copy = dict(engine_prompt)
        return prompt_copy, {}

    serving_chat._process_inputs = AsyncMock(side_effect=_passthrough_inputs)
    return serving_chat


399
400
@dataclass
class MockEngine:
    """Bare engine stand-in: a mock model config plus two ``MagicMock`` processors.

    Each instance gets freshly constructed defaults for all three fields.
    """

    model_config: MockModelConfig = field(default_factory=MockModelConfig)
    input_processor: MagicMock = field(default_factory=MagicMock)
    io_processor: MagicMock = field(default_factory=MagicMock)
404
405
406


async def _async_serving_chat_init():
    """Construct an ``OpenAIServingChat`` on a ``MockEngine`` and return it."""
    engine = MockEngine()
    serving_models = OpenAIServingModels(engine, BASE_MODEL_PATHS)

    serving_chat = OpenAIServingChat(
        engine,
        serving_models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )
    return serving_chat


def test_async_serving_chat_init():
    """The handler built inside an event loop keeps the supplied chat template."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE
424
425


426
427
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
    """The full model name is used whatever (or no) model name is requested."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    async def _echo_model_name(*args):
        # The fourth positional argument is what the assertions below compare
        # against MODEL_NAME.
        return args[3]

    serving_chat.chat_completion_full_generator = _echo_model_name

    messages = [{"role": "user", "content": "what is 1+1?"}]
    model_variants = (
        {"model": MODEL_NAME_SHORT},  # short name requested
        {"model": ""},  # empty string specified
        {},  # no model specified
    )
    for model_kwargs in model_variants:
        req = ChatCompletionRequest(messages=messages, **model_kwargs)
        assert await serving_chat.create_chat_completion(req) == MODEL_NAME


456
457
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` passed to the engine for request/server combinations.

    Three server setups are exercised: no server-side limit, a limit of 10
    (lower than context_window - prompt_tokens) and a limit of 200 (higher
    than context_window - prompt_tokens). Within each, the same request object
    is re-sent with successive ``max_tokens`` values.
    """

    def _engine_with(model_config):
        engine = MagicMock(spec=AsyncLLM)
        engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
        engine.errored = False
        engine.model_config = model_config
        engine.input_processor = MagicMock()
        engine.io_processor = MagicMock()
        return engine

    # (server-side diff_sampling_param, [(request max_tokens, expected), ...]);
    # a request value of None means "leave the request untouched".
    scenarios = [
        (None, [(None, 93), (10, 10)]),
        ({"max_tokens": 10}, [(None, 10), (15, 10), (5, 5)]),
        ({"max_tokens": 200}, [(None, 93), (100, 93), (5, 5)]),
    ]

    for server_params, steps in scenarios:
        model_config = MockModelConfig()
        if server_params is not None:
            model_config.diff_sampling_param = server_params

        mock_engine = _engine_with(model_config)
        serving_chat = _build_serving_chat(mock_engine)

        req = ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "what is 1+1?"}],
        )

        for requested, expected in steps:
            if requested is not None:
                req.max_tokens = requested

            # Errors raised after the engine call are irrelevant here; only
            # the sampling params handed to generate() are inspected.
            with suppress(Exception):
                await serving_chat.create_chat_completion(req)

            assert mock_engine.generate.call_args.args[1].max_tokens == expected

573

574
575
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
    """Server-side sampling defaults apply unless the request overrides them."""
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05,
    }

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # (request temperature, expected temperature); None leaves the request
    # as constructed so the server default is used. 0.0 is included to check
    # that a zero value set by the user is still honoured.
    steps = [(None, 0.5), (0.1, 0.1), (0.0, 0.0)]
    for requested, expected in steps:
        if requested is not None:
            req.temperature = requested

        with suppress(Exception):
            await serving_chat.create_chat_completion(req)

        sent_params = mock_engine.generate.call_args.args[1]
        assert sent_params.temperature == expected
        assert sent_params.repetition_penalty == 1.05
620
621


622
@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt(model_type):
    """``cache_salt`` reaches the engine prompt only when the request sets it.

    Args:
        model_type: Value placed in the mock HF config's ``model_type``.
    """
    mock_model_config = MockModelConfig()
    # Assign a fresh config instead of writing through
    # ``mock_model_config.hf_config.model_type``: ``hf_config`` is declared
    # without an annotation on MockModelConfig, so that object may be shared
    # by every instance and mutating it would leak into other tests.
    mock_model_config.hf_config = MockHFConfig(model_type=model_type)

    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = mock_model_config
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(mock_engine)

    # Test cache_salt
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # By default, cache_salt in the engine prompt is not set
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    # Second positional argument of the first _process_inputs call.
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "cache_salt" not in engine_prompt

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"
655
656
657
658
659
660
661
662
663
664


@pytest.mark.asyncio
async def test_serving_chat_data_parallel_rank_extraction():
    """Test that data_parallel_rank is properly extracted from header and
    passed to engine."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    # Mock the generate method to return an async generator
    async def mock_generate(*args, **kwargs):
        # Yield a fake RequestOutput
        from vllm.outputs import CompletionOutput, RequestOutput

        yield RequestOutput(
            request_id="test-request",
            prompt="test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="test response",
                    token_ids=[4, 5, 6],
                    cumulative_logprob=0.0,
                    logprobs=None,
                    finish_reason="stop",
                    stop_reason=None,
                )
            ],
            finished=True,
        )

    # A plain MagicMock is used on purpose: calling it returns the async
    # generator itself. An AsyncMock would wrap the call in a coroutine, so
    # `async for` over the result would raise TypeError (hidden by the
    # suppress() blocks below) and leave an un-awaited coroutine behind.
    mock_engine.generate = MagicMock(side_effect=mock_generate)

    serving_chat = _build_serving_chat(mock_engine)

    # Test when data_parallel_rank is present in header
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )

    # Mock request with X-data-parallel-rank header
    mock_raw_request = MagicMock()
    mock_raw_request.headers = {"X-data-parallel-rank": "2"}
    mock_raw_request.state = MagicMock()

    # Only the arguments passed to generate() matter; later failures in the
    # mocked pipeline are deliberately ignored.
    with suppress(Exception):
        await serving_chat.create_chat_completion(req, mock_raw_request)

    # Verify that data_parallel_rank was passed to engine.generate
    generate_kwargs = mock_engine.generate.call_args.kwargs
    assert "data_parallel_rank" in generate_kwargs
    assert generate_kwargs["data_parallel_rank"] == 2

    # Test when data_parallel_rank is not present (defaults to None)
    req_no_dp = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 2+2?"}],
    )

    # Mock request with no header
    mock_raw_request_no_dp = MagicMock()
    mock_raw_request_no_dp.headers = {}
    mock_raw_request_no_dp.state = MagicMock()

    with suppress(Exception):
        await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)

    # Verify that data_parallel_rank defaults to None
    generate_kwargs = mock_engine.generate.call_args.kwargs
    assert "data_parallel_rank" in generate_kwargs
    assert generate_kwargs["data_parallel_rank"] is None