# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
5
Test Execution Times (Last Run: 2025-12-13):
6
- test_request_cancellation_trtllm_aggregated: ~45s (gpu_1)
7
8
9
10
- test_request_cancellation_trtllm_decode_cancel: ~65s (gpu_1)
- test_request_cancellation_trtllm_prefill_cancel: ~65s (gpu_1)
- test_request_cancellation_trtllm_kv_transfer_cancel: ~65s (gpu_1)
- Total: ~240s x2 request planes = ~480s (0:08:00)
11
12
"""

13
14
15
16
17
18
19
20
21
import contextlib
import logging
import os
import shutil
import time

import pytest

from tests.fault_tolerance.cancellation.utils import (
    DynamoFrontendProcess,
    poll_for_pattern,
    read_streaming_responses,
    send_cancellable_request,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
from tests.utils.port_utils import allocate_port, deallocate_port

logger = logging.getLogger(__name__)

# Marks applied to every test collected from this module.
pytestmark = [
    # Suite / hardware / model selection.
    pytest.mark.trtllm,
    pytest.mark.gpu_1,
    pytest.mark.e2e,
    pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
    # post_merge to pinpoint failure commit
    pytest.mark.post_merge,
    # Each test runs once per request plane; the value is routed through the
    # `request_plane` fixture (indirect parametrization).
    pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
    # Cancellation is disabled for now, so every test is expected to fail;
    # strict=True reports an unexpected pass as a failure.
    pytest.mark.xfail(reason="Cancellation is temporarily disabled", strict=True),
]


class DynamoWorkerProcess(ManagedProcess):
    """Process manager for Dynamo worker with TensorRT-LLM backend.

    Runs ``python3 -m dynamo.trtllm`` in the requested disaggregation mode.
    ``__init__`` obtains a system port from ``allocate_port``. The port is
    handed back to ``deallocate_port`` at most once: on ``__exit__``, or as
    soon as construction or startup raises.
    """

    def __init__(
        self,
        request,
        frontend_port: int,
        mode: str = "prefill_and_decode",
    ):
        """
        Initialize TensorRT-LLM worker process.

        Args:
            request: pytest request object
            frontend_port: Port for the frontend server
            mode: One of "prefill_and_decode", "prefill", "decode"

        Side effects:
            For any mode other than "prefill_and_decode", writes
            ``test_request_cancellation_trtllm_config.yaml`` into the current
            working directory. Removes any existing log directory named
            ``<test node name>_<mode>_worker``.
        """
        # Allocate system port for this worker
        system_port = allocate_port(9100)
        self.system_port = system_port
        self._system_port_released = False
        self.frontend_port = frontend_port
        # is_ready() reads self.mode, so make it available before anything
        # (including base-class setup) could invoke the health check.
        self.mode = mode

        try:
            command = [
                "python3",
                "-m",
                "dynamo.trtllm",
                "--model",
                FAULT_TOLERANCE_MODEL_NAME,
                "--disaggregation-mode",
                mode,
                "--max-seq-len",
                "16384",
                "--max-num-tokens",
                "16384",
            ]
            if mode != "prefill_and_decode":
                # Disaggregated workers get extra engine settings via a YAML file.
                config_path = "test_request_cancellation_trtllm_config.yaml"
                with open(config_path, "w", encoding="utf-8") as f:
                    f.write(
                        "cache_transceiver_config:\n"
                        "  backend: DEFAULT\n"
                        "  max_tokens_in_buffer: 16384\n"
                        "disable_overlap_scheduler: true\n"
                        "kv_cache_config:\n"
                        "  max_tokens: 16384\n"
                    )
                command += ["--extra-engine-args", config_path]

            # Prefill/decode workers are checked through their own system port.
            # Otherwise readiness is judged via the frontend.
            if mode in ("prefill", "decode"):
                health_check_urls = [
                    (f"http://localhost:{system_port}/health", self.is_ready)
                ]
            else:
                health_check_urls = [
                    (f"http://localhost:{frontend_port}/v1/models", check_models_api),
                    (f"http://localhost:{frontend_port}/health", check_health_generate),
                ]

            # Set environment variables
            env = os.environ.copy()
            env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")
            env["DYN_LOG"] = "debug"
            # Disable canary health check - these tests expect full control over requests
            # sent to the workers where canary health check intermittently sends dummy
            # requests to workers interfering with the test process which may cause
            # intermittent failures
            env["DYN_HEALTH_CHECK_ENABLED"] = "false"
            env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
            env["DYN_SYSTEM_PORT"] = str(system_port)

            # One log directory per test node and worker mode; start from a clean one.
            log_dir = f"{request.node.name}_{mode}_worker"
            with contextlib.suppress(FileNotFoundError):
                shutil.rmtree(log_dir)
                logger.info(f"Cleaned up existing log directory: {log_dir}")

            super().__init__(
                command=command,
                env=env,
                health_check_urls=health_check_urls,
                timeout=300,
                display_output=True,
                terminate_all_matching_process_names=False,
                log_dir=log_dir,
            )
        except BaseException:
            # Python never calls __exit__ on an object whose __init__ raised,
            # so return the port here before propagating the error.
            self._release_system_port()
            raise

    def is_ready(self, response) -> bool:
        """Return True when the worker's health JSON reports status "ready".

        Used as the health-check callback for prefill/decode workers. Logs the
        observed status, or a warning if the body is not JSON, and returns
        False when not ready.
        """
        try:
            data = response.json()
        except ValueError:
            logger.warning(
                f"{self.mode.capitalize()} worker health response is not valid JSON"
            )
            return False
        if data.get("status") == "ready":
            logger.info(f"{self.mode.capitalize()} worker status is ready")
            return True
        logger.warning(
            f"{self.mode.capitalize()} worker status is not ready: {data.get('status')}"
        )
        return False

    def __enter__(self):
        """Start the worker, returning the system port if startup raises.

        Python skips ``__exit__`` when ``__enter__`` raises. ManagedProcess may
        or may not tear itself down in that case. Releasing here is safe either
        way because ``_release_system_port`` is idempotent.
        """
        try:
            return super().__enter__()
        except BaseException:
            self._release_system_port()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release allocated port when worker exits."""
        self._release_system_port()
        return super().__exit__(exc_type, exc_val, exc_tb)

    def _release_system_port(self) -> None:
        """Return ``self.system_port`` to the allocator; later calls are no-ops.

        This is best effort. A failure is logged rather than raised, so it
        cannot mask whatever exception triggered teardown.
        """
        if getattr(self, "_system_port_released", False):
            return
        self._system_port_released = True
        try:
            deallocate_port(self.system_port)
        except Exception as e:
            logger.warning(f"Failed to release TRT-LLM worker port: {e}")


@pytest.mark.timeout(135)  # 3x average
def test_request_cancellation_trtllm_aggregated(
    request, runtime_services_dynamic_ports, predownload_models
):
    """
    End-to-end test for request cancellation in aggregated mode.

    A single ``prefill_and_decode`` worker serves every request. For each
    scenario the test does the following:
    - sends a cancellable request and waits for the worker to log its ID;
    - cancels it from the client side;
    - expects the worker to log the abort for that ID and the frontend to log
      the Kill control message.

    Scenarios:
    1. Completion request
    2. Chat completion request (non-streaming)
    3. Chat completion request (streaming; 5 chunks are read before cancelling)

    Timing (Last Run: 2025-12-09): ~45s total
    - Engine initialization: ~27s (frontend + worker)
    - Testing 3 scenarios: ~15s (~5s each)
    - Teardown: ~3s
    """
    scenarios = (
        ("completion", "Completion request cancellation"),
        ("chat_completion", "Chat completion request cancellation"),
        ("chat_completion_stream", "Chat completion stream request cancellation"),
    )

    # The frontend allocates its own frontend_port; the worker allocates its
    # own system_port and is pointed at the frontend.
    with DynamoFrontendProcess(request) as frontend:
        logger.info("Frontend started successfully")

        with DynamoWorkerProcess(
            request, frontend.frontend_port, mode="prefill_and_decode"
        ) as worker:
            logger.info(f"Aggregated Worker PID: {worker.get_pid()}")

            # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
            time.sleep(2)

            # Log offsets returned by poll_for_pattern are carried across
            # scenarios and passed back in on the next poll.
            worker_offset = 0
            frontend_offset = 0

            for request_type, description in scenarios:
                logger.info(f"Testing {description.lower()}...")

                pending = send_cancellable_request(
                    frontend.frontend_port, request_type
                )

                request_id, worker_offset = poll_for_pattern(
                    process=worker,
                    pattern="New Request ID: ",
                    log_offset=worker_offset,
                    match_type="contains",
                )

                if request_type == "chat_completion_stream":
                    # Let some output flow so the cancel lands mid-stream.
                    read_streaming_responses(pending, expected_count=5)

                pending.cancel()
                logger.info(f"Cancelled request ID: {request_id}")

                _, worker_offset = poll_for_pattern(
                    process=worker,
                    pattern=f"Aborted Request ID: {request_id}",
                    log_offset=worker_offset,
                )
                _, frontend_offset = poll_for_pattern(
                    process=frontend,
                    pattern="issued control message Kill to sender",
                    log_offset=frontend_offset,
                )

                logger.info(f"{description} detected successfully")


@pytest.mark.timeout(195)  # 3x average
def test_request_cancellation_trtllm_decode_cancel(
    request, runtime_services_dynamic_ports, predownload_models
):
    """
    End-to-end test for request cancellation during the decode phase.

    Disaggregated setup: one frontend, one prefill worker, one decode worker.
    The test runs these steps:
    - sends a streaming chat completion;
    - observes its request ID on the prefill worker, then on the decode worker;
    - reads five streamed chunks and cancels the request;
    - expects the decode worker to log the abort for that ID and the frontend
      to log the Kill control message.

    Timing (Last Run: 2025-12-09): ~115s total (2 workers at 45% GPU each)
    - Engine initialization: ~92s (frontend: 2s, prefill worker: 45s, decode worker: 45s sequential)
    - Testing stream cancellation during decode: ~20s
    - Teardown: ~3s

    NOTE(review): the module docstring (Last Run: 2025-12-13) lists ~65s, which
    matches the "3x average" 195s timeout; the figures above may be stale.
    """
    # Frontend, then prefill worker, then decode worker; each allocates its own
    # port (frontend_port / system_port).
    with DynamoFrontendProcess(request) as frontend:
        logger.info("Frontend started successfully")

        with DynamoWorkerProcess(
            request, frontend.frontend_port, mode="prefill"
        ) as prefill_worker:
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

            with DynamoWorkerProcess(
                request, frontend.frontend_port, mode="decode"
            ) as decode_worker:
                logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")

                # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
                time.sleep(2)

                logger.info(
                    "Testing chat completion stream request cancellation in decode worker (decode phase)..."
                )
                stream = send_cancellable_request(
                    frontend.frontend_port, "chat_completion_stream"
                )

                # The frontend routes the request to the prefill worker first...
                request_id, _ = poll_for_pattern(
                    process=prefill_worker,
                    pattern="Prefill Request ID: ",
                    match_type="contains",
                )
                # ...and the same ID must then reach the decode worker.
                _, decode_offset = poll_for_pattern(
                    process=decode_worker,
                    pattern=f"Decode Request ID: {request_id}",
                )

                # Consume a few decode-phase chunks, then cancel mid-stream.
                read_streaming_responses(stream, expected_count=5)
                stream.cancel()
                logger.info(f"Cancelled request ID: {request_id}")

                poll_for_pattern(
                    process=decode_worker,
                    pattern=f"Aborted Request ID: {request_id}",
                    log_offset=decode_offset,
                )
                poll_for_pattern(
                    process=frontend,
                    pattern="issued control message Kill to sender",
                )

                logger.info(
                    "Chat completion stream cancellation in decode phase detected successfully"
                )


@pytest.mark.timeout(195)  # 3x average
def test_request_cancellation_trtllm_prefill_cancel(
    request, runtime_services_dynamic_ports, predownload_models
):
    """
    End-to-end test for request cancellation during the prefill phase.

    A long-prompt completion is sent and cancelled as soon as the prefill
    worker logs its request ID. The test expects the prefill worker to log the
    abort for that ID and the frontend to log the Kill control message. It also
    checks that the decode worker never logged any request ID.

    Timing (Last Run: 2025-12-09): ~115s total (2 workers at 45% GPU each)
    - Engine initialization: ~92s (frontend: 2s, prefill worker: 45s, decode worker: 45s sequential)
    - Testing cancellation during prefill: ~20s
    - Teardown: ~3s

    NOTE(review): the module docstring (Last Run: 2025-12-13) lists ~65s, which
    matches the "3x average" 195s timeout; the figures above may be stale.
    """
    # Frontend, then prefill worker, then decode worker; each allocates its own
    # port (frontend_port / system_port).
    with DynamoFrontendProcess(request) as frontend:
        logger.info("Frontend started successfully")

        with DynamoWorkerProcess(
            request, frontend.frontend_port, mode="prefill"
        ) as prefill_worker:
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

            with DynamoWorkerProcess(
                request, frontend.frontend_port, mode="decode"
            ) as decode_worker:
                logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")

                # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
                time.sleep(2)

                logger.info(
                    "Testing completion request cancellation during prefill phase..."
                )
                # A long prompt keeps the request in prefill long enough to cancel it.
                req = send_cancellable_request(
                    frontend.frontend_port, "completion", use_long_prompt=True
                )

                request_id, prefill_offset = poll_for_pattern(
                    process=prefill_worker,
                    pattern="Prefill Request ID: ",
                    match_type="contains",
                )

                req.cancel()
                logger.info(f"Cancelled request ID: {request_id} during prefill")

                # The abort is expected on the prefill worker, where the request lives.
                poll_for_pattern(
                    process=prefill_worker,
                    pattern=f"Aborted Request ID: {request_id}",
                    log_offset=prefill_offset,
                )
                poll_for_pattern(
                    process=frontend,
                    pattern="issued control message Kill to sender",
                )

                # The decode worker must never see a request cancelled during
                # prefill. poll_for_pattern reports "not found" by raising
                # AssertionError, so a normal return means the request leaked.
                # NOTE(review): the expected message ties this test to
                # poll_for_pattern's iteration count for max_wait_ms=10.
                pattern = "Request ID: "
                try:
                    poll_for_pattern(
                        process=decode_worker,
                        pattern=pattern,
                        max_wait_ms=10,
                        match_type="contains",
                    )
                except AssertionError as e:
                    assert str(e).startswith(
                        f"Failed to find '{pattern}' pattern after 2 iterations "
                    ), f"Unexpected error: {e}"
                else:
                    pytest.fail(
                        "Decode worker received request cancelled during prefill phase"
                    )

                logger.info(
                    "Completion request cancellation during prefill phase detected successfully"
                )


@pytest.mark.xfail(reason="Test fails only on CI", strict=False)
@pytest.mark.timeout(195)  # 3x average
430
def test_request_cancellation_trtllm_kv_transfer_cancel(
431
    request, runtime_services_dynamic_ports, predownload_models
432
):
433
434
435
436
437
    """
    End-to-end test for request cancellation during prefill to decode KV transfer phase.

    This test verifies that when a request is cancelled by the client during the KV transfer phase,
    the system properly handles the cancellation and cleans up resources on the workers.
438
439
440
441
442

    Timing (Last Run: 2025-12-09): ~115s total (2 workers at 45% GPU each)
    - Engine initialization: ~92s (frontend: 2s, prefill worker: 45s, decode worker: 45s sequential)
    - Testing KV transfer cancellation: ~20s
    - Teardown: ~3s
443
444
    """

445
    # Step 1: Start the frontend (allocates its own frontend_port)
446
447
448
    with DynamoFrontendProcess(request) as frontend:
        logger.info("Frontend started successfully")

449
450
451
452
        # Step 2: Start the prefill worker (allocates its own system_port)
        with DynamoWorkerProcess(
            request, frontend.frontend_port, mode="prefill"
        ) as prefill_worker:
453
454
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

455
456
457
458
            # Step 3: Start the decode worker (allocates its own system_port)
            with DynamoWorkerProcess(
                request, frontend.frontend_port, mode="decode"
            ) as decode_worker:
459
460
461
462
463
464
465
466
467
468
469
470
                logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")

                # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
                time.sleep(2)

                # Step 4: Test request cancellation during KV transfer phase
                logger.info(
                    "Testing completion request cancellation during KV transfer phase..."
                )

                # Send request with long prompt
                cancellable_req = send_cancellable_request(
471
                    frontend.frontend_port, "completion", use_long_prompt=True
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
                )

                # Poll for "Prefill Request ID" pattern in prefill worker
                request_id, prefill_log_offset = poll_for_pattern(
                    process=prefill_worker,
                    pattern="Prefill Request ID: ",
                    match_type="contains",
                )

                # Poll for decode worker entry signaling start of KV transfer phase
                _, decode_log_offset = poll_for_pattern(
                    process=decode_worker,
                    pattern=f"Decode Request ID: {request_id}",
                    poll_interval_ms=2,
                )

                # Cancel during KV transfer phase in decode worker
                cancellable_req.cancel()
                logger.info(
                    f"Cancelled request ID: {request_id} at beginning of decode"
                )

                # Poll for "Aborted Request ID" in decode worker
                _, decode_log_offset = poll_for_pattern(
                    process=decode_worker,
                    pattern=f"Aborted Request ID: {request_id}",
                    log_offset=decode_log_offset,
                )

                # Verify frontend log has kill message
                _, frontend_log_offset = poll_for_pattern(
                    process=frontend,
                    pattern="issued control message Kill to sender",
                )

                logger.info(
                    "Completion request cancellation at beginning of decode detected successfully"
                )

                # Verify the workers are still functional
512
513
514
                cancellable_req = send_cancellable_request(
                    frontend.frontend_port, "chat_completion_stream"
                )
515
516
517
518
519
520
521
522
523
524
525
                _, decode_log_offset = poll_for_pattern(
                    process=decode_worker,
                    pattern="Decode Request ID: ",
                    log_offset=decode_log_offset,
                    match_type="contains",
                )
                read_streaming_responses(cancellable_req, expected_count=5)

                logger.info(
                    "Workers are functional after cancellation during KV transfer"
                )