# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Request migration fault-tolerance tests for the TRT-LLM backend.

Test Execution Times (Last Run: 2026-01-12):
- test_request_migration_trtllm_aggregated: ~95s
- test_request_migration_trtllm_prefill: N/A
- test_request_migration_trtllm_kv_transfer: N/A
- test_request_migration_trtllm_decode: N/A
"""

import logging
import os
import shutil

import pytest

from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
19
from tests.utils.managed_process import ManagedProcess
20
from tests.utils.payloads import check_models_api
21
from tests.utils.port_utils import allocate_port, deallocate_port
22

23
24
# Customized utils for migration tests
from .utils import DynamoFrontendProcess, run_migration_test

logger = logging.getLogger(__name__)

# Module-wide marks: every test in this file carries the marks below and is
# parametrized over the cross product of the parametrizations that follow.
pytestmark = [
    pytest.mark.fault_tolerance,
    pytest.mark.trtllm,
    pytest.mark.gpu_1,
    pytest.mark.e2e,
    pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
    # > 0: migration is expected to succeed; 0: the request is expected to fail.
    pytest.mark.parametrize(
        "migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"]
    ),
    # Frontend migration_max_seq_len: unset, a limit that is not exceeded, and
    # a limit of 1 that is exceeded.
    pytest.mark.parametrize(
        "migration_max_seq_len",
        [
            pytest.param(None, id="max_seq_len_disabled"),
            pytest.param(1_000_000, id="max_seq_len_not_exceeded"),
            pytest.param(1, id="max_seq_len_exceeded"),
        ],
    ),
    # True: abrupt kill (SIGKILL); False: graceful shutdown (SIGTERM).
    pytest.mark.parametrize(
        "immediate_kill", [True, False], ids=["worker_failure", "graceful_shutdown"]
    ),
    # Only the chat API and streaming are exercised today; the other variants
    # are still collected but skipped.
    pytest.mark.parametrize(
        "request_api",
        [
            pytest.param("chat"),
            pytest.param(
                "completion",
                marks=pytest.mark.skip(reason="Behavior unverified yet"),
            ),
        ],
    ),
    pytest.mark.parametrize(
        "stream",
        [
            pytest.param(True, id="stream"),
            pytest.param(
                False,
                id="unary",
                marks=pytest.mark.skip(reason="Behavior unverified yet"),
            ),
        ],
    ),
    # indirect=True routes the value through the `request_plane` fixture; the
    # worker reads it back with request.getfixturevalue("request_plane").
    pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
]


class DynamoWorkerProcess(ManagedProcess):
    """Process manager for Dynamo worker with TRT-LLM backend

    Supports both aggregated mode (single worker) and disaggregated mode
    (separate prefill and decode workers).

    Args:
        request: pytest request fixture
        worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1")
        frontend_port: Port where the frontend is running
        mode: "prefill_and_decode" for aggregated, "prefill" or "decode" for disaggregated

    The worker owns a system port obtained from ``allocate_port`` and, in
    disaggregated modes, a generated ``--extra-engine-args`` YAML file in the
    current directory. Both are released exactly once:
    - after ``__exit__`` has run the base-class teardown; or
    - immediately if construction or ``__enter__`` fails, since Python never
      calls ``__exit__`` in those cases.
    """

    def __init__(
        self,
        request,
        worker_id: str,
        frontend_port: int,
        mode: str = "prefill_and_decode",
    ):
        self.worker_id = worker_id
        self.mode = mode
        # Cleanup bookkeeping is initialized before anything is acquired so
        # that _release_trtllm_resources() is safe from any failure point.
        self._trtllm_config_file = None
        self._trtllm_resources_released = False
        self.system_port = allocate_port(9100)

        try:
            command = self._trtllm_command(mode)
            env = self._trtllm_env(request, frontend_port)

            # Every worker must report "ready" on its own /health endpoint.
            # Decode and aggregated workers additionally gate on the
            # frontend's /v1/models endpoint (validated by check_models_api).
            health_check_urls = [
                (f"http://localhost:{self.system_port}/health", self.is_ready)
            ]
            if mode in ["decode", "prefill_and_decode"]:
                health_check_urls.append(
                    (f"http://localhost:{frontend_port}/v1/models", check_models_api)
                )

            # TODO: Have the managed process take a command name explicitly to distinguish
            #       between processes started with the same command.
            log_dir = self._reset_trtllm_log_dir(f"{request.node.name}_{worker_id}")

            super().__init__(
                command=command,
                env=env,
                health_check_urls=health_check_urls,
                timeout=300,
                display_output=True,
                terminate_all_matching_process_names=False,
                log_dir=log_dir,
                display_name=worker_id,
            )
        except BaseException:
            # A half-constructed object is never entered or exited, so release
            # the port (and any config file) here instead of leaking them.
            self._release_trtllm_resources()
            raise

    def _trtllm_command(self, mode: str) -> list:
        """Return the ``dynamo.trtllm`` launch command for ``mode``.

        Disaggregated modes also get a generated ``--extra-engine-args`` YAML
        file. Its path is recorded so the file can be removed on release.
        """
        command = [
            "python3",
            "-m",
            "dynamo.trtllm",
            "--model",
            FAULT_TOLERANCE_MODEL_NAME,
            "--disaggregation-mode",
            mode,
            "--max-seq-len",
            "8192",
            "--max-num-tokens",
            "8192",
            "--free-gpu-memory-fraction",
            "0.15",  # avoid validation error on TRT-LLM available memory checks
        ]
        if mode != "prefill_and_decode":
            config_file = (
                f"test_request_migration_trtllm_config_{self.system_port}.yaml"
            )
            # Record the path before writing so that a partially written file
            # is still removed if the write fails.
            self._trtllm_config_file = config_file
            with open(config_file, "w", encoding="utf-8") as f:
                f.write(
                    "cache_transceiver_config:\n"
                    "  backend: DEFAULT\n"
                    "  max_tokens_in_buffer: 8192\n"
                    "disable_overlap_scheduler: true\n"
                    "kv_cache_config:\n"
                    "  max_tokens: 8192\n"
                )
            command += ["--extra-engine-args", config_file]
        return command

    def _trtllm_env(self, request, frontend_port: int) -> dict:
        """Return a copy of ``os.environ`` with the worker's DYN_* settings applied."""
        env = os.environ.copy()
        env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
        env["DYN_LOG"] = "debug"
        # Disable canary health check - these tests expect full control over requests
        # sent to the workers where canary health check intermittently sends dummy
        # requests to workers interfering with the test process which may cause
        # intermittent failures
        env["DYN_HEALTH_CHECK_ENABLED"] = "false"
        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
        env["DYN_SYSTEM_PORT"] = str(self.system_port)
        env["DYN_HTTP_PORT"] = str(frontend_port)
        # Disable backend shutdown grace period for all migration tests
        env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"
        return env

    @staticmethod
    def _reset_trtllm_log_dir(log_dir: str) -> str:
        """Delete ``log_dir`` left over from a previous run and return its name."""
        try:
            shutil.rmtree(log_dir)
            logger.info(f"Cleaned up existing log directory: {log_dir}")
        except FileNotFoundError:
            # Directory doesn't exist, which is fine
            pass
        return log_dir

    def _release_trtllm_resources(self):
        """Release the system port and remove the generated config file.

        Best-effort: failures are logged, not raised. Only the first call
        does any work.
        """
        if self._trtllm_resources_released:
            return
        self._trtllm_resources_released = True

        try:
            deallocate_port(self.system_port)
        except Exception as e:
            logger.warning(f"Failed to release TRT-LLM worker port: {e}")

        if self._trtllm_config_file is not None:
            try:
                os.remove(self._trtllm_config_file)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(
                    f"Failed to remove TRT-LLM config file {self._trtllm_config_file}: {e}"
                )

    def __enter__(self):
        """Start the worker; release owned resources if startup fails.

        ``__exit__`` is not invoked when ``__enter__`` raises (e.g. a worker
        that never passes its health checks), so clean up here in that case.
        """
        try:
            return super().__enter__()
        except BaseException:
            self._release_trtllm_resources()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Run the base-class teardown, then release the port and config file.

        The port is returned only after the base class has handled shutdown,
        not while the worker may still be bound to it.
        """
        try:
            return super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            self._release_trtllm_resources()

    def is_ready(self, response) -> bool:
        """Return True when the worker's health JSON reports ``status == "ready"``."""
        try:
            data = response.json()
        except ValueError:
            logger.warning(f"{self.worker_id} health response is not valid JSON")
            return False

        # A non-object JSON body (list, string, ...) has no status: not ready.
        status = data.get("status") if isinstance(data, dict) else None
        if status == "ready":
            logger.info(f"{self.worker_id} status is ready")
            return True
        logger.warning(f"{self.worker_id} status is not ready: {status}")
        return False


@pytest.mark.timeout(290)  # 3x average
@pytest.mark.nightly
def test_request_migration_trtllm_aggregated(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration between two aggregated TRT-LLM workers.

    A frontend configured with the migration settings under test is started,
    followed by two workers in the default "prefill_and_decode" mode; the
    migration scenario itself is driven by ``run_migration_test``.

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    use_chat_completion = request_api == "chat"

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")
        port = frontend.frontend_port

        # Step 2: Start 2 workers
        with DynamoWorkerProcess(request, "worker1", port) as first_worker:
            logger.info(f"Worker 1 PID: {first_worker.get_pid()}")

            with DynamoWorkerProcess(request, "worker2", port) as second_worker:
                logger.info(f"Worker 2 PID: {second_worker.get_pid()}")

                # Step 3: Run migration test
                run_migration_test(
                    frontend,
                    first_worker,
                    second_worker,
                    receiving_pattern="AggregatedHandler Request ID: ",
                    migration_limit=migration_limit,
                    migration_max_seq_len=migration_max_seq_len,
                    immediate_kill=immediate_kill,
                    use_chat_completion=use_chat_completion,
                    stream=stream,
                )

@pytest.mark.skip(
    reason="Prefill migration not yet supported, XFail eats up CI time due to timeout"
)
@pytest.mark.timeout(350)  # 3x average
@pytest.mark.nightly
def test_request_migration_trtllm_prefill(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end test for prefill worker request migration in disaggregated mode.

    Setup: 1 decode worker + 2 prefill workers. The migration scenario runs
    against the two prefill workers using a long prompt (use_long_prompt=True).
    Currently skipped: prefill migration is not yet supported.

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # Step 2: Start decode worker first (required for prefill workers to connect)
        with DynamoWorkerProcess(
            request,
            "worker0",
            frontend.frontend_port,
            mode="decode",
        ) as decode_worker:
            logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")

            # Step 3: Start 2 prefill workers
            with DynamoWorkerProcess(
                request,
                "worker1",
                frontend.frontend_port,
                mode="prefill",
            ) as prefill1:
                logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    "worker2",
                    frontend.frontend_port,
                    mode="prefill",
                ) as prefill2:
                    logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")

                    # Step 4: Run migration test
                    run_migration_test(
                        frontend,
                        prefill1,
                        prefill2,
                        receiving_pattern="Prefill Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=(request_api == "chat"),
                        stream=stream,
                        use_long_prompt=True,
                    )


@pytest.mark.skip(reason="Decode worker can get stuck downloading kv cache")
@pytest.mark.timeout(350)  # 3x average
@pytest.mark.nightly
def test_request_migration_trtllm_kv_transfer(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end test for request migration during KV transfer in disaggregated mode.

    Setup: 1 prefill worker + 2 decode workers. The migration scenario runs
    against the two decode workers using a long prompt (use_long_prompt=True).
    Currently skipped: the decode worker can get stuck downloading the KV cache.

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # Step 2: Start prefill worker first
        with DynamoWorkerProcess(
            request,
            "worker0",
            frontend.frontend_port,
            mode="prefill",
        ) as prefill_worker:
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

            # Step 3: Start 2 decode workers
            with DynamoWorkerProcess(
                request,
                "worker1",
                frontend.frontend_port,
                mode="decode",
            ) as decode1:
                logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    "worker2",
                    frontend.frontend_port,
                    mode="decode",
                ) as decode2:
                    logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")

                    # Step 4: Run migration test
                    run_migration_test(
                        frontend,
                        decode1,
                        decode2,
                        receiving_pattern="Decode Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=(request_api == "chat"),
                        stream=stream,
                        use_long_prompt=True,
                    )


@pytest.mark.timeout(350)  # 3x average
@pytest.mark.nightly
def test_request_migration_trtllm_decode(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end test for decode worker request migration in disaggregated mode.

    Setup: 1 prefill worker + 2 decode workers. The migration scenario runs
    against the two decode workers and only stops a worker after a new
    response has arrived (wait_for_new_response_before_stop=True), which is
    why the non-streaming variant is skipped.

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    if not stream:
        pytest.skip(
            "Decode test requires streaming to wait for response before stopping worker"
        )

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # Step 2: Start prefill worker first
        with DynamoWorkerProcess(
            request,
            "worker0",
            frontend.frontend_port,
            mode="prefill",
        ) as prefill_worker:
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

            # Step 3: Start 2 decode workers
            with DynamoWorkerProcess(
                request,
                "worker1",
                frontend.frontend_port,
                mode="decode",
            ) as decode1:
                logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    "worker2",
                    frontend.frontend_port,
                    mode="decode",
                ) as decode2:
                    logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")

                    # Step 4: Run migration test
                    run_migration_test(
                        frontend,
                        decode1,
                        decode2,
                        receiving_pattern="Decode Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=(request_api == "chat"),
                        stream=stream,
                        wait_for_new_response_before_stop=True,
                    )