test_router_e2e_with_mockers.py 27.2 KB
Newer Older
1
2
3
4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
5
from contextlib import nullcontext
6
from typing import Any, Dict, Optional
7
8
9

import pytest

10
from tests.conftest import EtcdServer, NatsServer
11
from tests.router.common import (  # utilities
12
    _test_busy_threshold_endpoint,
13
14
15
    _test_python_router_bindings,
    _test_router_basic,
    _test_router_decisions,
16
    _test_router_decisions_disagg,
17
18
19
20
21
22
23
    _test_router_indexers_sync,
    _test_router_overload_503,
    _test_router_query_instance_id,
    _test_router_two_routers,
    generate_random_suffix,
    get_runtime,
)
Alec's avatar
Alec committed
24
from tests.utils.constants import ROUTER_MODEL_NAME
25
26
from tests.utils.managed_process import ManagedProcess

27
28
29
30
# Module-level logger shared by every helper and test in this file.
logger = logging.getLogger(__name__)

# Model handed to the mocker CLI (--model-path) and to the router helpers.
MODEL_NAME = ROUTER_MODEL_NAME

# Markers applied to every test collected from this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.gpu_0,
    pytest.mark.integration,
    pytest.mark.model(MODEL_NAME),
]

# Default number of mocker workers started per test (--num-workers).
NUM_MOCKERS = 2
# Forwarded to the mocker as --speedup-ratio.
SPEEDUP_RATIO = 10.0
# Base port for all tests (high port to avoid conflicts).
BASE_PORT = 9100
# Passed as num_requests to the router test helpers.
NUM_REQUESTS = 100
# Forwarded to the mocker as --block-size and to the router helpers as block_size.
BLOCK_SIZE = 16
42
43


44
def get_unique_ports(
    request,
    num_ports: int = 1,
    store_backend: str = "etcd",
    request_plane: str = "nats",
    registration_order: str = "prefill_first",
) -> list[int]:
    """Compute a block of consecutive ports for the calling test.

    The first port is ``BASE_PORT`` plus the sum of:
    - a per-test offset looked up by test function name (tests that are not
      listed in the table fall back to offset 0),
    - a parametrization offset (etcd=0, otherwise 50; nats=0, otherwise 25;
      prefill_first=0, otherwise 10).

    Subsequent ports follow consecutively (port index for multi-port tests).

    Args:
        request: Pytest request fixture; only ``request.node.name`` is read.
        num_ports: Number of ports needed (1 for single router, 2 for two routers)
        store_backend: Storage backend parameter ("etcd" or "file")
        request_plane: Request plane parameter ("nats" or "tcp")
        registration_order: Registration order parameter ("prefill_first" or "decode_first")

    Returns:
        List of ``num_ports`` consecutive port numbers
    """
    # Drop the "[param-id]" suffix so all parametrizations share one table entry.
    function_name = request.node.name.split("[")[0]

    # Each test function owns a 100-port wide range.
    offsets_by_test = {
        "test_mocker_kv_router": 0,
        "test_mocker_two_kv_router": 100,
        "test_mocker_kv_router_overload_503": 200,
        "test_query_instance_id_returns_worker_and_tokens": 300,
        "test_router_decisions_disagg": 400,
        "test_busy_threshold_endpoint": 500,
    }

    first_port = BASE_PORT + offsets_by_test.get(function_name, 0)

    # Shift within the test's range according to the parametrization values.
    if store_backend != "etcd":
        first_port += 50
    if request_plane != "nats":
        first_port += 25
    if registration_order != "prefill_first":
        first_port += 10

    return list(range(first_port, first_port + num_ports))


96
97
98
99
100
101
102
103
104
105
106
107
108
# Shared request payload passed as ``test_payload`` to the router test helpers:
# a single user message for MODEL_NAME, streamed, capped at 10 generated tokens.
TEST_PAYLOAD: Dict[str, Any] = {
    "model": MODEL_NAME,
    "messages": [
        {
            "role": "user",
            "content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
        }
    ],
    "stream": True,
    "max_tokens": 10,
}

109

110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def _build_mocker_command(
    endpoint: str,
    store_backend: str,
    num_workers: int,
    mocker_args: Dict[str, Any],
    worker_type: Optional[str] = None,
) -> list[str]:
    """Assemble the argv list used to launch ``python -m dynamo.mocker``.

    Args:
        endpoint: The dynamo endpoint string (``--endpoint``)
        store_backend: Storage backend, "etcd" or "file" (``--store-kv``)
        num_workers: Number of workers to spawn (``--num-workers``)
        mocker_args: Optional mocker settings. Recognised keys:
            speedup_ratio, block_size, num_gpu_blocks, max_num_seqs,
            max_num_batched_tokens, enable_prefix_caching,
            enable_chunked_prefill, watermark, dp_size, enable_local_indexer.
            Unrecognised keys are ignored.
        worker_type: Optional worker type ("prefill" or "decode") for disagg
            mode; any other value adds no role flag.

    Returns:
        List of command arguments for subprocess
    """
    argv = ["python", "-m", "dynamo.mocker"]
    argv += ["--model-path", MODEL_NAME]
    argv += ["--endpoint", endpoint]
    argv += ["--store-kv", store_backend]
    argv += ["--num-workers", str(num_workers)]

    # Role flag for disaggregated mode.
    for role, role_flag in (
        ("prefill", "--is-prefill-worker"),
        ("decode", "--is-decode-worker"),
    ):
        if worker_type == role:
            argv.append(role_flag)
            break

    # (key, flag, negated flag). A negated flag of None marks a value option
    # emitted as "<flag> <value>"; otherwise the key is a boolean toggle.
    # The tuple order is the order in which flags appear on the command line.
    option_table = (
        ("speedup_ratio", "--speedup-ratio", None),
        ("block_size", "--block-size", None),
        ("num_gpu_blocks", "--num-gpu-blocks-override", None),
        ("max_num_seqs", "--max-num-seqs", None),
        ("max_num_batched_tokens", "--max-num-batched-tokens", None),
        (
            "enable_prefix_caching",
            "--enable-prefix-caching",
            "--no-enable-prefix-caching",
        ),
        (
            "enable_chunked_prefill",
            "--enable-chunked-prefill",
            "--no-enable-chunked-prefill",
        ),
        ("watermark", "--watermark", None),
        ("dp_size", "--data-parallel-size", None),
    )
    for key, flag, negated_flag in option_table:
        if key not in mocker_args:
            continue
        value = mocker_args[key]
        if negated_flag is None:
            argv += [flag, str(value)]
        else:
            argv.append(flag if value else negated_flag)

    # Unlike the toggles above, this one is only emitted when truthy.
    if mocker_args.get("enable_local_indexer"):
        argv.append("--enable-local-indexer")

    return argv


184
class MockerProcess:
    """Context manager around one mocker process started with --num-workers.

    Each instance gets its own randomly suffixed namespace and exposes
    ``namespace``, ``component_name``, ``endpoint`` and ``num_workers``.
    """

    def __init__(
        self,
        request,
        mocker_args: Optional[Dict[str, Any]] = None,
        num_mockers: int = 1,
        store_backend: str = "etcd",
        request_plane: str = "nats",
    ):
        """Prepare (but do not start) the mocker process.

        Args:
            request: Pytest request fixture; its node name is used as log dir.
            mocker_args: Optional mocker settings, see ``_build_mocker_command``.
            num_mockers: Number of workers passed via --num-workers.
            store_backend: Storage backend passed via --store-kv.
            request_plane: Value exported to the child as DYN_REQUEST_PLANE.
        """
        self.namespace = f"test-namespace-{generate_random_suffix()}"
        self.component_name = "mocker"
        self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
        self.num_workers = num_mockers

        # The child inherits the current environment plus the request plane.
        child_env = {**os.environ, "DYN_REQUEST_PLANE": request_plane}

        self._process = ManagedProcess(
            command=_build_mocker_command(
                endpoint=self.endpoint,
                store_backend=store_backend,
                num_workers=num_mockers,
                mocker_args=mocker_args or {},
            ),
            env=child_env,
            timeout=60,
            display_output=True,
            health_check_ports=[],
            health_check_urls=[],
            log_dir=request.node.name,
            terminate_existing=False,
        )
        logger.info(
            f"Created mocker process with {num_mockers} worker(s), endpoint: {self.endpoint}"
        )

    def __enter__(self):
        """Start the underlying managed process and return this wrapper."""
        logger.info(f"Starting mocker process with {self.num_workers} worker(s)")
        self._process.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the underlying managed process; exceptions are not suppressed."""
        logger.info("Stopping mocker process")
        self._process.__exit__(exc_type, exc_val, exc_tb)


class DisaggMockerProcess:
    """Context manager for prefill or decode mockers in disaggregated serving.

    Uses --num-workers for shared tokio runtime. The endpoint depends on the role:
    - worker_type="prefill" -> dyn://<namespace>.prefill.generate
    - worker_type="decode"  -> dyn://<namespace>.backend.generate

    Both prefill and decode workers should share the same namespace for proper discovery.
    """

    def __init__(
        self,
        request,
        namespace: str,
        worker_type: str,
        mocker_args: Optional[Dict[str, Any]] = None,
        num_mockers: int = 1,
        store_backend: str = "etcd",
    ):
        """Prepare (but do not start) the mocker process for one role.

        Args:
            request: Pytest request fixture; its node name is used as log dir.
            namespace: Namespace shared by the prefill and decode workers.
            worker_type: Either "prefill" or "decode".
            mocker_args: Optional mocker settings, see ``_build_mocker_command``.
            num_mockers: Number of workers passed via --num-workers.
            store_backend: Storage backend passed via --store-kv.

        Raises:
            ValueError: If ``worker_type`` is neither "prefill" nor "decode".
        """
        if worker_type not in ("prefill", "decode"):
            raise ValueError(
                f"worker_type must be 'prefill' or 'decode', got {worker_type}"
            )

        self.namespace = namespace
        self.worker_type = worker_type
        self.num_workers = num_mockers

        # Decode workers register under the "backend" component.
        self.component_name = "prefill" if worker_type == "prefill" else "backend"
        self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"

        self._process = ManagedProcess(
            command=_build_mocker_command(
                endpoint=self.endpoint,
                store_backend=store_backend,
                num_workers=num_mockers,
                mocker_args=mocker_args or {},
                worker_type=worker_type,
            ),
            timeout=60,
            display_output=True,
            health_check_ports=[],
            health_check_urls=[],
            log_dir=request.node.name,
            terminate_existing=False,
        )
        logger.info(
            f"Created {worker_type} mocker process with {num_mockers} worker(s), "
            f"endpoint: {self.endpoint}"
        )

    def __enter__(self):
        """Start the underlying managed process and return this wrapper."""
        logger.info(
            f"Starting {self.worker_type} mocker process with {self.num_workers} worker(s)"
        )
        self._process.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the underlying managed process; exceptions are not suppressed."""
        logger.info(f"Stopping {self.worker_type} mocker process")
        self._process.__exit__(exc_type, exc_val, exc_tb)
307
308


309
@pytest.mark.parallel
def test_mocker_kv_router(request, runtime_services_session, predownload_tokenizers):
    """Test KV router with multiple mocker engine instances.

    This test doesn't require GPUs and runs quickly for pre-merge validation.
    """
    # runtime_services starts etcd and nats
    logger.info("Starting mocker KV router test")

    # Create mocker args dictionary
    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}

    # Start mocker instances with the new CLI interface
    logger.info(f"Starting {NUM_MOCKERS} mocker instances")
    mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS)
    logger.info(f"All mockers using endpoint: {mockers.endpoint}")

    # The with-statement stops the mockers on exit, including when the test fails.
    with mockers:
        # Get unique port for this test
        frontend_port = get_unique_ports(request, num_ports=1)[0]

        # Run basic router test (starts router internally and waits for workers to be ready)
        _test_router_basic(
            engine_workers=mockers,
            block_size=BLOCK_SIZE,
            request=request,
            frontend_port=frontend_port,
            test_payload=TEST_PAYLOAD,
            num_requests=NUM_REQUESTS,
        )
347
348


349
@pytest.mark.parallel
@pytest.mark.parametrize("store_backend", ["etcd", "file"])
def test_mocker_two_kv_router(
    request,
    runtime_services_session,
    predownload_tokenizers,
    file_storage_backend,
    store_backend,
):
    """Test with two KV routers and multiple mocker engine instances.

    Alternates requests between the two routers to test load distribution.
    Tests with both etcd and file storage backends.
    """
    # runtime_services starts etcd and nats
    logger.info(
        f"Starting mocker two KV router test with {store_backend} storage backend"
    )

    # Create mocker args dictionary
    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}

    # Start mocker instances with the new CLI interface
    logger.info(f"Starting {NUM_MOCKERS} mocker instances")
    mockers = MockerProcess(
        request,
        mocker_args=mocker_args,
        num_mockers=NUM_MOCKERS,
        store_backend=store_backend,
    )
    logger.info(f"All mockers using endpoint: {mockers.endpoint}")

    # The with-statement stops the mockers on exit, including when the test fails.
    with mockers:
        # Get unique ports for this test (2 ports for two routers)
        router_ports = get_unique_ports(
            request, num_ports=2, store_backend=store_backend
        )

        # Run two-router test (starts KV routers internally and manages their lifecycle)
        _test_router_two_routers(
            engine_workers=mockers,
            block_size=BLOCK_SIZE,
            request=request,
            router_ports=router_ports,
            test_payload=TEST_PAYLOAD,
            num_requests=NUM_REQUESTS,
            store_backend=store_backend,
        )
403
404


405
@pytest.mark.parallel
@pytest.mark.skip(reason="Flaky, temporarily disabled")
def test_mocker_kv_router_overload_503(
    request, runtime_services_session, predownload_tokenizers
):
    """Test that KV router returns 503 when mocker workers are overloaded."""
    logger.info("Starting mocker KV router overload test for 503 status")

    # Single source of truth so the mocker and the router helper always agree.
    block_size = 4  # Smaller block size

    # Create mocker args dictionary with limited resources
    mocker_args = {
        "speedup_ratio": 10,
        "block_size": block_size,
        "num_gpu_blocks": 64,  # Limited GPU blocks to exhaust quickly
    }

    # Start single mocker instance with limited resources
    logger.info("Starting single mocker instance with limited resources")
    mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=1)
    logger.info(f"Mocker using endpoint: {mockers.endpoint}")

    # The with-statement stops the mocker on exit, including when the test fails.
    with mockers:
        # Get unique port for this test
        frontend_port = get_unique_ports(request, num_ports=1)[0]

        # Run overload 503 test
        _test_router_overload_503(
            engine_workers=mockers,
            block_size=block_size,  # Match the mocker's block size
            request=request,
            frontend_port=frontend_port,
            test_payload=TEST_PAYLOAD,
            blocks_threshold=0.2,
        )
442

443

444
@pytest.mark.parallel
def test_kv_push_router_bindings(
    request, runtime_services_session, predownload_tokenizers
):
    """Test KvPushRouter Python bindings with mocker engines."""
    logger.info("Starting KvPushRouter bindings test")
    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}

    # Start mocker instances
    logger.info(f"Starting {NUM_MOCKERS} mocker instances")
    mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS)
    logger.info(f"All mockers using endpoint: {mockers.endpoint}")

    # The with-statement stops the mockers on exit, including when the test fails.
    with mockers:
        # Get runtime and resolve the endpoint the mockers registered under
        runtime = get_runtime()
        namespace = runtime.namespace(mockers.namespace)
        component = namespace.component(mockers.component_name)
        endpoint = component.endpoint("generate")

        # Run Python router bindings test
        _test_python_router_bindings(
            engine_workers=mockers,
            endpoint=endpoint,
            block_size=BLOCK_SIZE,
            model_name=MODEL_NAME,
            num_workers=NUM_MOCKERS,
        )


481
482
483
484
485
486
487
488
489
490
# NO @pytest.mark.parallel - nats_core variant stops/restarts NATS
@pytest.mark.parametrize(
    "store_backend,use_nats_core,request_plane",
    [
        ("etcd", False, "nats"),  # JetStream mode
        # ("etcd", True, "tcp"),  # ignored, needs unconditional nats_client
        ("file", False, "nats"),  # File backend
    ],
    ids=["jetstream", "file"],  # "nats_core" commented out to match commented test case
)
def test_indexers_sync(
    request,
    predownload_tokenizers,
    file_storage_backend,
    store_backend,
    use_nats_core,
    request_plane,
):
    """Test that two KV routers have synchronized indexer states after processing requests.

    This test verifies that both routers converge to the same internal state.

    Configurations currently enabled:
    - jetstream: etcd backend, JetStream for KV events, NATS request plane
    - file: file backend, JetStream for KV events, NATS request plane

    Currently disabled (commented out in the parametrization):
    - nats_core: etcd backend, local indexer with NATS Core, TCP request plane
                 (includes NATS interruption/recovery testing)
    """
    logger.info(
        f"Starting indexers sync test: store_backend={store_backend}, "
        f"use_nats_core={use_nats_core}, request_plane={request_plane}"
    )

    # Start NATS manually (needed for all variants - KV event sync)
    with NatsServer(request) as nats_server:
        # Start etcd if needed
        etcd_ctx = EtcdServer(request) if store_backend == "etcd" else nullcontext()
        with etcd_ctx:
            # Create mocker args dictionary
            mocker_args = {
                "speedup_ratio": SPEEDUP_RATIO,
                "block_size": BLOCK_SIZE,
                "enable_local_indexer": use_nats_core,
            }

            # Start mocker instances
            logger.info(f"Starting {NUM_MOCKERS} mocker instances")
            mockers = MockerProcess(
                request,
                mocker_args=mocker_args,
                num_mockers=NUM_MOCKERS,
                store_backend=store_backend,
                request_plane=request_plane,
            )
            logger.info(f"All mockers using endpoint: {mockers.endpoint}")

            # The with-statement stops the mockers before etcd/NATS are torn down.
            with mockers:
                # Use the common test implementation (creates its own runtimes for each router)
                # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
                _test_router_indexers_sync(
                    engine_workers=mockers,
                    block_size=BLOCK_SIZE,
                    model_name=MODEL_NAME,
                    num_workers=NUM_MOCKERS,
                    store_backend=store_backend,
                    request_plane=request_plane,
                    test_nats_interruption=use_nats_core,
                    nats_server=nats_server if use_nats_core else None,
                )

                logger.info("Indexers sync test completed successfully")
557

558

559
@pytest.mark.parallel
def test_query_instance_id_returns_worker_and_tokens(
    request, runtime_services_session, predownload_tokenizers
):
    """Test query_instance_id annotation with mocker engines."""
    logger.info("Starting KV router query_instance_id annotation test")
    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}

    # Create the directory named after the test node (also used as the mockers' log_dir).
    os.makedirs(request.node.name, exist_ok=True)

    # Start mocker instances
    logger.info(f"Starting {NUM_MOCKERS} mocker instances")
    mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS)
    logger.info(f"All mockers using endpoint: {mockers.endpoint}")

    # The with-statement stops the mockers on exit, including when the test fails.
    with mockers:
        # Get unique port for this test
        frontend_port = get_unique_ports(request, num_ports=1)[0]

        # Run query_instance_id annotation test
        _test_router_query_instance_id(
            engine_workers=mockers,
            block_size=BLOCK_SIZE,
            request=request,
            frontend_port=frontend_port,
            test_payload=TEST_PAYLOAD,
        )
592
593


594
@pytest.mark.parallel
@pytest.mark.parametrize("use_nats_core", [False, True], ids=["jetstream", "nats_core"])
def test_router_decisions(
    request, runtime_services_session, predownload_tokenizers, use_nats_core
):
    """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.

    Parameterized to test both JetStream (default) and NATS Core (local indexer) modes.
    """
    # runtime_services starts etcd and nats
    mode = "NATS Core (local indexer)" if use_nats_core else "JetStream"
    logger.info(
        f"Starting test router prefix reuse and KV events synchronization ({mode})"
    )

    # Create mocker args dictionary with dp_size=4
    mocker_args = {
        "speedup_ratio": SPEEDUP_RATIO,
        "block_size": BLOCK_SIZE,
        "dp_size": 4,
        "enable_local_indexer": use_nats_core,
    }

    logger.info(
        f"Starting 2 mocker instances with dp_size=4 each (8 total dp ranks), {mode}"
    )
    mockers = MockerProcess(request, mocker_args=mocker_args, num_mockers=2)
    logger.info(f"All mockers using endpoint: {mockers.endpoint}")

    # The with-statement stops the mockers on exit, including when the test fails.
    with mockers:
        # Get runtime and resolve the endpoint from the mockers' own
        # namespace/component instead of repeating the component name here.
        runtime = get_runtime()
        namespace = runtime.namespace(mockers.namespace)
        component = namespace.component(mockers.component_name)
        endpoint = component.endpoint("generate")

        _test_router_decisions(
            mockers, endpoint, MODEL_NAME, request, test_dp_rank=True
        )
642
643


644
@pytest.mark.parallel
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
def test_router_decisions_disagg(
    request, runtime_services_session, predownload_tokenizers, registration_order
):
    """Validate KV cache prefix reuse in disaggregated prefill-decode setup.

    Tests that progressive requests with overlapping prefixes are routed to the
    same prefill worker due to KV cache reuse.

    Parameterized to test both registration orders:
    - prefill_first: prefill workers register before decode workers
    - decode_first: decode workers register before prefill workers
    """
    # Local import: the file's top-level imports only bring in nullcontext.
    from contextlib import ExitStack

    logger.info(
        f"Starting disaggregated router prefix reuse test "
        f"(registration_order={registration_order})"
    )

    # Generate shared namespace for prefill and decode workers
    namespace_suffix = generate_random_suffix()
    shared_namespace = f"test-namespace-{namespace_suffix}"

    # Create mocker args
    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}

    workers_per_role = 4
    if registration_order == "prefill_first":
        start_order = ("prefill", "decode")
    else:
        start_order = ("decode", "prefill")

    # ExitStack stops every started worker group (in reverse start order), even
    # if a later group fails to start or the test body raises.
    with ExitStack() as stack:
        workers = {}
        for position, worker_type in zip(("first", "second"), start_order):
            logger.info(
                f"Starting {workers_per_role} {worker_type} mocker instances ({position})"
            )
            workers[worker_type] = stack.enter_context(
                DisaggMockerProcess(
                    request,
                    namespace=shared_namespace,
                    worker_type=worker_type,
                    mocker_args=mocker_args,
                    num_mockers=workers_per_role,
                )
            )
            logger.info(
                f"{worker_type.capitalize()} workers using endpoint: "
                f"{workers[worker_type].endpoint}"
            )

        # Get unique port for this test
        frontend_port = get_unique_ports(
            request, num_ports=1, registration_order=registration_order
        )[0]

        # Run disagg routing test
        _test_router_decisions_disagg(
            prefill_workers=workers["prefill"],
            decode_workers=workers["decode"],
            block_size=BLOCK_SIZE,
            request=request,
            frontend_port=frontend_port,
            test_payload=TEST_PAYLOAD,
        )
743
744


745
@pytest.mark.parallel
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
def test_busy_threshold_endpoint(
    request, runtime_services_session, predownload_tokenizers, request_plane
):
    """Test that the /busy_threshold endpoint can be hit and responds correctly.

    TODO: This doesn't actually test any e2e rejection for now. A proper test would:
    1. Set a very low threshold
    2. Send enough requests to exceed the threshold
    3. Verify that subsequent requests are rejected with 503

    For now, this test only verifies the endpoint is accessible and returns valid responses.
    """
    logger.info(
        f"Starting busy_threshold endpoint test with request_plane={request_plane}"
    )

    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}

    logger.info(f"Starting {NUM_MOCKERS} mocker instances")
    mockers = MockerProcess(
        request,
        mocker_args=mocker_args,
        num_mockers=NUM_MOCKERS,
        request_plane=request_plane,
    )
    logger.info(f"All mockers using endpoint: {mockers.endpoint}")

    # The with-statement stops the mockers on exit, including when the test fails.
    with mockers:
        frontend_port = get_unique_ports(
            request, num_ports=1, request_plane=request_plane
        )[0]

        _test_busy_threshold_endpoint(
            engine_workers=mockers,
            block_size=BLOCK_SIZE,
            request=request,
            frontend_port=frontend_port,
            test_payload=TEST_PAYLOAD,
            request_plane=request_plane,
        )