# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Test Execution Times (Last Run: 2026-01-13):
- test_request_migration_sglang_aggregated: ~75s
- test_request_migration_sglang_prefill: N/A (test is currently skipped)
- test_request_migration_sglang_kv_transfer: N/A (test is currently skipped)
- test_request_migration_sglang_decode: ~75s
"""
import logging
import os
import shutil

import pytest

from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
19
from tests.utils.managed_process import ManagedProcess
20
from tests.utils.payloads import check_models_api
21
from tests.utils.port_utils import allocate_port, deallocate_port
22

23
24
# Customized utils for migration tests
from .utils import DynamoFrontendProcess, run_migration_test
25
26
27
28

logger = logging.getLogger(__name__)

# Module-wide marks: every test in this file receives these marks and is
# parametrized over the cross-product of the axes below (minus skipped params).
pytestmark = [
    pytest.mark.fault_tolerance,
    pytest.mark.sglang,
    pytest.mark.gpu_1,
    pytest.mark.e2e,
    pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
    # Frontend migration limit: 3 expects the request to be migrated and
    # succeed; 0 expects it to fail.
    pytest.mark.parametrize(
        "migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"]
    ),
    # Frontend max sequence length used for migration state tracking.
    pytest.mark.parametrize(
        "migration_max_seq_len",
        [
            pytest.param(None, id="max_seq_len_disabled"),
            pytest.param(1_000_000, id="max_seq_len_not_exceeded"),
            pytest.param(1, id="max_seq_len_exceeded"),
        ],
    ),
    # How the worker is stopped mid-request: True = abrupt kill,
    # False = graceful shutdown.
    pytest.mark.parametrize(
        "immediate_kill",
        [
            pytest.param(True, id="worker_failure"),
            pytest.param(False, id="graceful_shutdown"),
        ],
    ),
    pytest.mark.parametrize(
        "request_api",
        [
            pytest.param("chat"),
            pytest.param(
                "completion",
                marks=pytest.mark.skip(reason="Behavior unverified yet"),
            ),
        ],
    ),
    pytest.mark.parametrize(
        "stream",
        [
            pytest.param(True, id="stream"),
            pytest.param(
                False,
                id="unary",
                marks=pytest.mark.skip(reason="Behavior unverified yet"),
            ),
        ],
    ),
    # Indirect: the value is routed through the `request_plane` fixture.
    pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
]


class DynamoWorkerProcess(ManagedProcess):
    """Process manager for Dynamo worker with SGLang backend

    Supports both aggregated mode (single worker) and disaggregated mode
    (separate prefill and decode workers).

    Args:
        request: pytest request fixture
        worker_id: Unique identifier for the worker (e.g., "worker1", "worker2").
            It is part of the log directory name. In disaggregated mode, its
            last character is appended to "1234" to form the bootstrap port.
        frontend_port: Port where the frontend is running
        disagg_mode: None for aggregated, "prefill" or "decode" for disaggregated

    Raises:
        ValueError: if ``disagg_mode`` is not None, "prefill" or "decode".
    """

    def __init__(
        self,
        request,
        worker_id: str,
        frontend_port: int,
        disagg_mode: str | None = None,
    ):
        # Validate before reserving any resources.
        if disagg_mode not in (None, "prefill", "decode"):
            raise ValueError(
                f"disagg_mode must be None, 'prefill' or 'decode', got {disagg_mode!r}"
            )

        self.worker_id = worker_id
        self.system_port = allocate_port(9100)
        self.disagg_mode = disagg_mode

        try:
            command = self._sglang_command(worker_id, disagg_mode)
            env = self._sglang_env(
                request.getfixturevalue("request_plane"),
                self.system_port,
                frontend_port,
            )

            # Configure health check based on worker type. Every worker is polled
            # on its own /health endpoint. Aggregated and decode workers are also
            # checked through the frontend's /v1/models endpoint.
            health_check_urls = [
                (f"http://localhost:{self.system_port}/health", self.is_ready)
            ]
            if disagg_mode in (None, "decode"):
                health_check_urls.append(
                    (f"http://localhost:{frontend_port}/v1/models", check_models_api)
                )

            # TODO: Have the managed process take a command name explicitly to distinguish
            #       between processes started with the same command.
            log_dir = f"{request.node.name}_{worker_id}"
            self._remove_stale_log_dir(log_dir)

            super().__init__(
                command=command,
                env=env,
                health_check_urls=health_check_urls,
                timeout=300,
                display_output=True,
                terminate_all_matching_process_names=False,
                stragglers=["SGLANG:EngineCore"],
                straggler_commands=["-m dynamo.sglang"],
                log_dir=log_dir,
            )
        except BaseException:
            # __exit__ is never reached when construction fails. Release the
            # reserved port here so it is not leaked. BaseException also covers
            # pytest outcome exceptions raised from getfixturevalue.
            self._release_system_port()
            raise

    @staticmethod
    def _sglang_command(worker_id: str, disagg_mode: str | None) -> list[str]:
        """Build the ``python3 -m dynamo.sglang`` command line for one worker."""
        command = [
            "python3",
            "-m",
            "dynamo.sglang",
            "--model-path",
            FAULT_TOLERANCE_MODEL_NAME,
            "--served-model-name",
            FAULT_TOLERANCE_MODEL_NAME,
            "--trust-remote-code",
            "--page-size",
            "16",
            "--tp",
            "1",
            "--mem-fraction-static",
            "0.3",
            "--context-length",
            "8192",
        ]

        if disagg_mode is None:
            # Aggregated
            command.append("--skip-tokenizer-init")
            return command

        # Disaggregated. The bootstrap port is "1234" + last char of worker_id,
        # e.g. 12341 for "worker1". Assumes worker IDs end in distinct digits.
        command.extend(
            [
                "--disaggregation-mode",
                disagg_mode,
                "--disaggregation-bootstrap-port",
                f"1234{worker_id[-1]}",
                "--host",
                "0.0.0.0",
                "--disaggregation-transfer-backend",
                "nixl",
            ]
        )
        if disagg_mode == "prefill":
            # NOTE(review): every prefill worker gets the same fixed --port, so
            # two prefill workers on one host may collide. Confirm whether
            # dynamo.sglang actually binds this port.
            command.extend(["--port", "40000"])
        return command

    @staticmethod
    def _sglang_env(
        request_plane: str, system_port: int, frontend_port: int
    ) -> dict[str, str]:
        """Return a copy of ``os.environ`` with the worker's Dynamo settings applied."""
        env = os.environ.copy()
        env["DYN_REQUEST_PLANE"] = request_plane
        env["DYN_LOG"] = "debug"

        # Disable canary health check - these tests expect full control over requests
        # sent to the workers where canary health check intermittently sends dummy
        # requests to workers interfering with the test process which may cause
        # intermittent failures
        env["DYN_HEALTH_CHECK_ENABLED"] = "false"
        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
        env["DYN_SYSTEM_PORT"] = str(system_port)
        env["DYN_HTTP_PORT"] = str(frontend_port)

        # Disable backend shutdown grace period for all migration tests
        env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"
        return env

    @staticmethod
    def _remove_stale_log_dir(log_dir: str) -> None:
        """Delete ``log_dir`` left over from a previous run, if it exists."""
        try:
            shutil.rmtree(log_dir)
            logger.info(f"Cleaned up existing log directory: {log_dir}")
        except FileNotFoundError:
            # Directory doesn't exist, which is fine
            pass

    def _release_system_port(self) -> None:
        """Return ``self.system_port`` to the allocator; log rather than raise on failure."""
        try:
            deallocate_port(self.system_port)
        except Exception as e:
            logger.warning(f"Failed to release SGLang worker port: {e}")

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Run the base-class exit handling, then release the worker's system port.

        The port is released in ``finally`` after ``ManagedProcess.__exit__``.
        The base class is presumably what stops the worker process (TODO
        confirm). Releasing afterwards means the port cannot be handed to
        another process while this worker may still hold it.
        """
        try:
            return super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            self._release_system_port()

    def is_ready(self, response) -> bool:
        """Health-check callback: True iff the JSON body has ``"status": "ready"``.

        Malformed or non-object JSON is logged and treated as not ready, so
        polling continues instead of raising.
        """
        try:
            data = response.json()
        except ValueError:
            logger.warning(f"{self.worker_id} health response is not valid JSON")
            return False

        status = data.get("status") if isinstance(data, dict) else None
        if status == "ready":
            logger.info(f"{self.worker_id} status is ready")
            return True
        logger.warning(f"{self.worker_id} status is not ready: {status}")
        return False


@pytest.mark.timeout(230)  # 3x average
@pytest.mark.post_merge
def test_request_migration_sglang_aggregated(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end test for aggregated worker request migration.

    Setup: 1 frontend + 2 aggregated workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming

    The runtime_services_dynamic_ports, set_ucx_tls_no_mm and predownload_models
    fixtures are requested for their setup only; they are not referenced here.
    """

    # OPS-4446: first-token delay routinely exceeds the 6s threshold in
    # utils.validate_response for this parameter combination. Originally only
    # the NATS variant tripped; once the NATS skip landed, the TCP variant
    # started failing the same way (now bears the cold-start cost first).
    if (
        migration_limit == 3
        and migration_max_seq_len is None
        and immediate_kill is True
        and request_api == "chat"
        and stream is True
    ):
        pytest.skip("Flaky: first-token delay > 6s threshold. OPS-4446")

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # Step 2: Start 2 workers
        with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1:
            logger.info(f"Worker 1 PID: {worker1.get_pid()}")

            with DynamoWorkerProcess(
                request,
                "worker2",
                frontend.frontend_port,
            ) as worker2:
                logger.info(f"Worker 2 PID: {worker2.get_pid()}")

                # Step 3: Run migration test
                run_migration_test(
                    frontend,
                    worker1,
                    worker2,
                    receiving_pattern="New Request ID: ",
                    migration_limit=migration_limit,
                    migration_max_seq_len=migration_max_seq_len,
                    immediate_kill=immediate_kill,
                    use_chat_completion=(request_api == "chat"),
                    stream=stream,
                )

@pytest.mark.skip(reason="Cannot reliably migrate at Prefill that finish < 1 ms")
@pytest.mark.xfail(strict=False, reason="Prefill migration not yet supported")
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_prefill(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end test for prefill worker request migration in disaggregated mode.

    Setup: 1 decode worker + 2 prefill workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # Step 2: Start decode worker first (required for prefill workers to connect)
        with DynamoWorkerProcess(
            request,
            "worker0",
            frontend.frontend_port,
            disagg_mode="decode",
        ) as decode_worker:
            logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")

            # Step 3: Start 2 prefill workers
            with DynamoWorkerProcess(
                request,
                "worker1",
                frontend.frontend_port,
                disagg_mode="prefill",
            ) as prefill1:
                logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    "worker2",
                    frontend.frontend_port,
                    disagg_mode="prefill",
                ) as prefill2:
                    logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")

                    # Step 4: Run migration test. The long prompt keeps the
                    # request in prefill while a prefill worker is stopped.
                    run_migration_test(
                        frontend,
                        prefill1,
                        prefill2,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=(request_api == "chat"),
                        stream=stream,
                        use_long_prompt=True,
                    )

@pytest.mark.skip(reason="KV cache transfer may fail")
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_kv_transfer(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end test for request migration during KV transfer in disaggregated mode.

    Setup: 1 prefill worker + 2 decode workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # Step 2: Start prefill worker first
        with DynamoWorkerProcess(
            request,
            "worker0",
            frontend.frontend_port,
            disagg_mode="prefill",
        ) as prefill_worker:
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

            # Step 3: Start 2 decode workers
            with DynamoWorkerProcess(
                request,
                "worker1",
                frontend.frontend_port,
                disagg_mode="decode",
            ) as decode1:
                logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    "worker2",
                    frontend.frontend_port,
                    disagg_mode="decode",
                ) as decode2:
                    logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")

                    # Step 4: Run migration test (long prompt, so more KV data
                    # has to be transferred to the decode worker)
                    run_migration_test(
                        frontend,
                        decode1,
                        decode2,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=(request_api == "chat"),
                        stream=stream,
                        use_long_prompt=True,
                    )


@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_decode(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end test for decode worker request migration in disaggregated mode.

    Setup: 1 prefill worker + 2 decode workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """

    # The worker is stopped only after a response has started arriving
    # (wait_for_new_response_before_stop below), which needs streaming.
    if not stream:
        pytest.skip(
            "Decode test requires streaming to wait for response before stopping worker"
        )

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # Step 2: Start prefill worker first
        with DynamoWorkerProcess(
            request,
            "worker0",
            frontend.frontend_port,
            disagg_mode="prefill",
        ) as prefill_worker:
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

            # Step 3: Start 2 decode workers
            with DynamoWorkerProcess(
                request,
                "worker1",
                frontend.frontend_port,
                disagg_mode="decode",
            ) as decode1:
                logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    "worker2",
                    frontend.frontend_port,
                    disagg_mode="decode",
                ) as decode2:
                    logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")

                    # Step 4: Run migration test
                    run_migration_test(
                        frontend,
                        decode1,
                        decode2,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=(request_api == "chat"),
                        stream=stream,
                        wait_for_new_response_before_stop=True,
                    )