Unverified commit 95a750f4, authored by Yan Ru Pei and committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 210bbf5d
...@@ -4,8 +4,11 @@ ...@@ -4,8 +4,11 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use std::time::Duration;
use dashmap::DashMap; use dashmap::DashMap;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_kv_router::config::{KvRouterConfig, RouterPrefillLoadModel};
use tokio::sync::{Notify, Semaphore, mpsc}; use tokio::sync::{Notify, Semaphore, mpsc};
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tokio::time::Instant; use tokio::time::Instant;
...@@ -14,8 +17,8 @@ use uuid::Uuid; ...@@ -14,8 +17,8 @@ use uuid::Uuid;
use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs, SglangArgs}; use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs, SglangArgs};
use crate::loadgen::{SessionTrace, Trace, TurnTrace}; use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::ReplayRouterMode; use crate::replay::ReplayRouterMode;
use crate::replay::router::ReplayRouter;
use super::ReplayRouter;
use super::entrypoints::{ use super::entrypoints::{
simulate_concurrency_requests_with_stats, simulate_concurrency_workload_with_stats, simulate_concurrency_requests_with_stats, simulate_concurrency_workload_with_stats,
simulate_trace_requests, simulate_trace_requests_with_stats, simulate_trace_requests, simulate_trace_requests_with_stats,
...@@ -55,6 +58,21 @@ fn request(uuid: u128, token: u32, arrival_timestamp_ms: Option<f64>) -> DirectR ...@@ -55,6 +58,21 @@ fn request(uuid: u128, token: u32, arrival_timestamp_ms: Option<f64>) -> DirectR
} }
} }
/// Test double for [`PrefillLoadEstimator`] that answers every prediction
/// with the same, preconfigured duration.
struct FixedPrefillLoadEstimator {
    /// Value handed back by every call to `predict_prefill_duration`.
    duration: Duration,
}

impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
    /// Returns the configured duration; all three sizing arguments are ignored.
    fn predict_prefill_duration(
        &self,
        _batch_size: usize,
        _effective_isl: usize,
        _prefix: usize,
    ) -> anyhow::Result<Duration> {
        // `Duration` is `Copy`, so destructuring through the reference copies it out.
        let Self { duration } = *self;
        Ok(duration)
    }
}
fn multiturn_trace() -> Trace { fn multiturn_trace() -> Trace {
Trace { Trace {
block_size: 1, block_size: 1,
...@@ -96,9 +114,16 @@ fn test_online_trace_replay_single_worker_completes() { ...@@ -96,9 +114,16 @@ fn test_online_trace_replay_single_worker_completes() {
let args = replay_args(); let args = replay_args();
let requests = vec![request(1, 11, Some(0.0)), request(2, 22, Some(1.0))]; let requests = vec![request(1, 11, Some(0.0)), request(2, 22, Some(1.0))];
let report = let report = simulate_trace_requests(
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin) args,
.unwrap(); None,
None,
requests,
1,
1.0,
ReplayRouterMode::RoundRobin,
)
.unwrap();
assert_eq!(report.request_counts.num_requests, 2); assert_eq!(report.request_counts.num_requests, 2);
assert_eq!(report.request_counts.completed_requests, 2); assert_eq!(report.request_counts.completed_requests, 2);
...@@ -165,6 +190,7 @@ async fn test_trace_arrivals_are_not_blocked_by_queued_router_selection() { ...@@ -165,6 +190,7 @@ async fn test_trace_arrivals_are_not_blocked_by_queued_router_selection() {
ReplayRouterMode::KvRouter, ReplayRouterMode::KvRouter,
&args, &args,
None, None,
None,
1, 1,
)); ));
let senders: Arc<[mpsc::UnboundedSender<DirectRequest>]> = let senders: Arc<[mpsc::UnboundedSender<DirectRequest>]> =
...@@ -218,6 +244,50 @@ async fn test_trace_arrivals_are_not_blocked_by_queued_router_selection() { ...@@ -218,6 +244,50 @@ async fn test_trace_arrivals_are_not_blocked_by_queued_router_selection() {
router.shutdown().await.unwrap(); router.shutdown().await.unwrap();
} }
/// With prefill-token tracking on, the `Aic` prefill load model selected and a
/// fixed 10 s estimate, worker 0 must report 64 potential prefill tokens right
/// after routing, 32 once the paused clock has advanced 5 s, and 0 after 10 s.
#[tokio::test(start_paused = true)]
async fn test_online_kv_router_prefill_load_estimator_decays_active_tokens() {
    let args = replay_args();
    let config = KvRouterConfig {
        router_track_prefill_tokens: true,
        router_prefill_load_model: RouterPrefillLoadModel::Aic,
        ..KvRouterConfig::default()
    };
    let router = ReplayRouter::new(
        ReplayRouterMode::KvRouter,
        &args,
        Some(config),
        Some(Arc::new(FixedPrefillLoadEstimator {
            duration: Duration::from_secs(10),
        })),
        1,
    );

    // Single place that reads the tracked prefill tokens of the first load entry.
    let prefill_tokens = || router.debug_potential_loads(0, true)[0].potential_prefill_tokens;

    let selected = router
        .select_worker(&request(1, 11, Some(0.0)), 1)
        .await
        .unwrap();
    assert_eq!(selected, 0);
    assert_eq!(prefill_tokens(), 64);

    // Two 5 s steps of the paused clock: halfway through, then at the full estimate.
    for expected in [32, 0] {
        tokio::time::advance(Duration::from_secs(5)).await;
        assert_eq!(prefill_tokens(), expected);
    }

    router.shutdown().await.unwrap();
}
#[tokio::test] #[tokio::test]
async fn test_workload_wakeup_is_not_lost_when_completion_happens_before_await() { async fn test_workload_wakeup_is_not_lost_when_completion_happens_before_await() {
let mut driver = Trace { let mut driver = Trace {
...@@ -369,9 +439,16 @@ fn test_online_trace_replay_populates_admit_reuse_stats() { ...@@ -369,9 +439,16 @@ fn test_online_trace_replay_populates_admit_reuse_stats() {
let args = replay_args(); let args = replay_args();
let requests = vec![request(1, 77, Some(0.0)), request(2, 77, Some(5.0))]; let requests = vec![request(1, 77, Some(0.0)), request(2, 77, Some(5.0))];
let report = let report = simulate_trace_requests(
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin) args,
.unwrap(); None,
None,
requests,
1,
1.0,
ReplayRouterMode::RoundRobin,
)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2); assert_eq!(report.request_counts.completed_requests, 2);
assert!(report.prefix_cache_reused_ratio > 0.0); assert!(report.prefix_cache_reused_ratio > 0.0);
...@@ -395,9 +472,16 @@ fn test_online_trace_replay_sglang_single_worker_completes() { ...@@ -395,9 +472,16 @@ fn test_online_trace_replay_sglang_single_worker_completes() {
let args = sglang_replay_args(); let args = sglang_replay_args();
let requests = vec![request(101, 7, Some(0.0)), request(102, 8, Some(1.0))]; let requests = vec![request(101, 7, Some(0.0)), request(102, 8, Some(1.0))];
let report = let report = simulate_trace_requests(
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin) args,
.unwrap(); None,
None,
requests,
1,
1.0,
ReplayRouterMode::RoundRobin,
)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2); assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(report.request_counts.total_output_tokens, 4); assert_eq!(report.request_counts.total_output_tokens, 4);
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod offline;
mod online;
mod shared;
pub(crate) use offline::OfflineReplayRouter;
#[cfg(test)]
pub(crate) use offline::OfflineRouterSnapshot;
pub(crate) use online::ReplayRouter;
...@@ -175,6 +175,66 @@ mod core_behavior { ...@@ -175,6 +175,66 @@ mod core_behavior {
); );
} }
/// Two four-token requests queued on a core configured with
/// `max_num_batched_tokens = 8` must both be admitted by a single
/// `execute_pass`, sharing the same admit and first-token timestamps.
#[test]
fn test_execute_pass_batches_two_ready_requests_together() {
    let engine_args = MockEngineArgs::builder()
        .block_size(4)
        .num_gpu_blocks(16)
        .max_num_batched_tokens(Some(8))
        .max_num_seqs(Some(4))
        .enable_chunked_prefill(true)
        .enable_prefix_caching(false)
        .speedup_ratio(0.0)
        .build()
        .unwrap();
    let mut core = VllmCore::new(engine_args);

    let r1 = Uuid::from_u128(101);
    let r2 = Uuid::from_u128(202);

    // Each request carries four copies of its own fill token.
    for (uuid, fill) in [(r1, 1), (r2, 2)] {
        core.receive(DirectRequest {
            tokens: vec![fill; 4],
            max_output_tokens: 1,
            uuid: Some(uuid),
            dp_rank: 0,
            arrival_timestamp_ms: None,
        });
    }

    let mut collector = crate::replay::TraceCollector::default();
    for uuid in [r1, r2] {
        collector.on_arrival(uuid, 0.0, 4, 1);
    }

    let pass = core.execute_pass(&mut collector, 0.0);
    let admitted: Vec<_> = pass
        .admissions
        .iter()
        .map(|admission| admission.uuid)
        .collect();
    let first = collector.snapshot(r1).unwrap();
    let second = collector.snapshot(r2).unwrap();

    // Both requests appear in the pass's admission list.
    assert_eq!(pass.admissions.len(), 2);
    assert!(admitted.contains(&r1));
    assert!(admitted.contains(&r2));

    // Both were admitted and produced a token.
    assert!(
        first.first_admit_ms.is_some(),
        "r1 should have been admitted"
    );
    assert!(
        second.first_admit_ms.is_some(),
        "r2 should have been admitted"
    );
    assert!(
        first.first_token_ms.is_some(),
        "r1 should have emitted a token"
    );
    assert!(
        second.first_token_ms.is_some(),
        "r2 should have emitted a token"
    );

    // Identical timestamps show they were processed in the same pass.
    assert_eq!(first.first_admit_ms, second.first_admit_ms);
    assert_eq!(first.first_token_ms, second.first_token_ms);
}
#[test] #[test]
fn test_prefill_completion_emits_handoff_delay() { fn test_prefill_completion_emits_handoff_delay() {
let args = MockEngineArgs::builder() let args = MockEngineArgs::builder()
......
...@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional
import aiohttp import aiohttp
import nats import nats
from dynamo.llm import KvRouter, KvRouterConfig from dynamo.llm import AicPerfConfig, KvRouter, KvRouterConfig
from tests.router.helper import ( from tests.router.helper import (
_nats_server, _nats_server,
assert_event_dumps_equal, assert_event_dumps_equal,
...@@ -1260,6 +1260,7 @@ def _test_router_decisions_disagg( ...@@ -1260,6 +1260,7 @@ def _test_router_decisions_disagg(
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats", request_plane: str = "nats",
durable_kv_events: bool = False, durable_kv_events: bool = False,
router_aic_config: Optional[dict[str, Any]] = None,
): ):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend. """Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend.
...@@ -1282,6 +1283,7 @@ def _test_router_decisions_disagg( ...@@ -1282,6 +1283,7 @@ def _test_router_decisions_disagg(
test_payload: Base test payload to send to /v1/chat/completions test_payload: Base test payload to send to /v1/chat/completions
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd". store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False. durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
router_aic_config: Optional AIC router perf-model config for frontend KV routing.
Raises: Raises:
AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure) AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
...@@ -1297,6 +1299,7 @@ def _test_router_decisions_disagg( ...@@ -1297,6 +1299,7 @@ def _test_router_decisions_disagg(
request_plane=request_plane, request_plane=request_plane,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
min_initial_workers=decode_workers.num_workers, min_initial_workers=decode_workers.num_workers,
router_aic_config=router_aic_config,
): ):
# Start KV router frontend - uses decode_workers namespace for discovery # Start KV router frontend - uses decode_workers namespace for discovery
# The frontend will auto-discover both prefill and decode workers # The frontend will auto-discover both prefill and decode workers
...@@ -1479,6 +1482,7 @@ def _test_router_decisions( ...@@ -1479,6 +1482,7 @@ def _test_router_decisions(
durable_kv_events: bool = False, durable_kv_events: bool = False,
router_event_threads: int = 4, router_event_threads: int = 4,
standalone_indexer_url: Optional[str] = None, standalone_indexer_url: Optional[str] = None,
router_aic_config: Optional[dict[str, Any]] = None,
): ):
"""Validate cross-worker routing decisions based on longest prefix match and tree-size tiebreaking. """Validate cross-worker routing decisions based on longest prefix match and tree-size tiebreaking.
...@@ -1503,6 +1507,7 @@ def _test_router_decisions( ...@@ -1503,6 +1507,7 @@ def _test_router_decisions(
use_kv_events: If True (default), uses KV events from workers. If False, uses use_kv_events: If True (default), uses KV events from workers. If False, uses
approximate routing with TTL-based expiration (--no-kv-events mode). approximate routing with TTL-based expiration (--no-kv-events mode).
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False. durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
router_aic_config: Optional AIC router perf-model config for direct KvRouter tests.
Raises: Raises:
AssertionError: If routing decisions don't match expected prefix/tiebreak logic AssertionError: If routing decisions don't match expected prefix/tiebreak logic
...@@ -1524,12 +1529,22 @@ def _test_router_decisions( ...@@ -1524,12 +1529,22 @@ def _test_router_decisions(
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads, router_event_threads=router_event_threads,
router_track_prefill_tokens=True,
router_prefill_load_model=(
"aic" if router_aic_config is not None else "none"
),
)
aic_perf_config = (
AicPerfConfig(**router_aic_config)
if router_aic_config is not None
else None
) )
with min_initial_workers_env(expected_num_instances): with min_initial_workers_env(expected_num_instances):
kv_router = KvRouter( kv_router = KvRouter(
endpoint=endpoint, endpoint=endpoint,
block_size=block_size, block_size=block_size,
kv_router_config=kv_router_config, kv_router_config=kv_router_config,
aic_perf_config=aic_perf_config,
) )
# Wait for workers to be ready and get their instance IDs # Wait for workers to be ready and get their instance IDs
......
...@@ -9,9 +9,10 @@ from typing import Any ...@@ -9,9 +9,10 @@ from typing import Any
from tests.router.common import ( from tests.router.common import (
_test_router_basic, _test_router_basic,
_test_router_decisions, _test_router_decisions,
_test_router_decisions_disagg,
_test_router_indexers_sync, _test_router_indexers_sync,
) )
from tests.router.helper import get_runtime from tests.router.helper import generate_random_suffix, get_runtime
from tests.utils.constants import DefaultPort from tests.utils.constants import DefaultPort
from tests.utils.port_utils import allocate_ports, deallocate_ports from tests.utils.port_utils import allocate_ports, deallocate_ports
from tests.utils.test_output import resolve_test_output_path from tests.utils.test_output import resolve_test_output_path
...@@ -229,6 +230,57 @@ def run_router_decisions_test( ...@@ -229,6 +230,57 @@ def run_router_decisions_test(
) )
def run_disagg_router_decisions_test(
    *,
    engine_process_cls,
    engine_args_name: str,
    engine_args: dict[str, Any],
    request,
    request_plane: str,
    model_name: str,
    block_size: int,
    num_prefill_workers: int,
    num_decode_workers: int,
    prefill_process_kwargs: dict[str, Any] | None = None,
    decode_process_kwargs: dict[str, Any] | None = None,
):
    """Start prefill and decode worker groups and run the disagg router-decision checks.

    Both groups are created from ``engine_process_cls`` and default to one freshly
    generated namespace; an explicit ``"namespace"`` entry in either
    ``*_process_kwargs`` dict takes precedence over that default. The engine
    arguments are passed under the keyword named by ``engine_args_name``.
    The prefill group is entered first and the decode group second, after which
    ``_test_router_decisions_disagg`` is invoked with both.
    """
    shared_namespace = f"test-namespace-{generate_random_suffix()}"
    frontend_port = allocate_frontend_ports(request, 1)[0]

    def _launch(num_workers, overrides):
        # Caller-supplied overrides are applied last so they win over the default namespace.
        process_kwargs = {"namespace": shared_namespace}
        process_kwargs.update(overrides or {})
        return engine_process_cls(
            request,
            num_workers=num_workers,
            request_plane=request_plane,
            **{engine_args_name: engine_args},
            **process_kwargs,
        )

    # The decode group is only constructed after the prefill group has been entered.
    with _launch(
        num_prefill_workers, prefill_process_kwargs
    ) as prefill_workers, _launch(
        num_decode_workers, decode_process_kwargs
    ) as decode_workers:
        _test_router_decisions_disagg(
            prefill_workers=prefill_workers,
            decode_workers=decode_workers,
            block_size=block_size,
            request=request,
            frontend_port=frontend_port,
            test_payload=build_test_payload(model_name),
            request_plane=request_plane,
        )
def run_indexers_sync_test( def run_indexers_sync_test(
*, *,
engine_process_cls, engine_process_cls,
......
...@@ -29,6 +29,7 @@ class FrontendRouterProcess(ManagedProcess): ...@@ -29,6 +29,7 @@ class FrontendRouterProcess(ManagedProcess):
durable_kv_events: bool = False, durable_kv_events: bool = False,
router_mode: str = "kv", router_mode: str = "kv",
min_initial_workers: int | None = None, min_initial_workers: int | None = None,
router_aic_config: dict[str, str | int] | None = None,
): ):
command = [ command = [
"python3", "python3",
...@@ -64,6 +65,30 @@ class FrontendRouterProcess(ManagedProcess): ...@@ -64,6 +65,30 @@ class FrontendRouterProcess(ManagedProcess):
if durable_kv_events: if durable_kv_events:
command.append("--router-durable-kv-events") command.append("--router-durable-kv-events")
if router_aic_config is not None:
command.extend(
[
"--router-track-prefill-tokens",
"--router-prefill-load-model",
"aic",
"--aic-backend",
str(router_aic_config["aic_backend"]),
"--aic-system",
str(router_aic_config["aic_system"]),
"--aic-model-path",
str(router_aic_config["aic_model_path"]),
"--aic-tp-size",
str(router_aic_config.get("aic_tp_size", 1)),
]
)
if "aic_backend_version" in router_aic_config:
command.extend(
[
"--aic-backend-version",
str(router_aic_config["aic_backend_version"]),
]
)
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane env["DYN_REQUEST_PLANE"] = request_plane
if min_initial_workers is not None: if min_initial_workers is not None:
......
...@@ -64,6 +64,20 @@ PLANNER_PROFILE_DATA_DIR = ( ...@@ -64,6 +64,20 @@ PLANNER_PROFILE_DATA_DIR = (
Path(__file__).resolve().parents[2] Path(__file__).resolve().parents[2]
/ "components/src/dynamo/planner/tests/data/profiling_results/H200_TP1P_TP1D" / "components/src/dynamo/planner/tests/data/profiling_results/H200_TP1P_TP1D"
) )
# Router-side AIC settings shared by the router AIC tests in this module.
# Tests obtain a copy via _require_router_aic() and pass it on as
# `router_aic_config`.
ROUTER_AIC_CONFIG = {
"aic_backend": "vllm",
"aic_system": "h200_sxm",
"aic_backend_version": "0.12.0",
"aic_tp_size": 1,
"aic_model_path": "Qwen/Qwen3-32B",
}
def _require_router_aic() -> dict[str, Any]:
    """Skip the calling test unless ``aiconfigurator`` is importable.

    Returns:
        A new shallow copy of ``ROUTER_AIC_CONFIG``, so callers can modify the
        result without touching the module-level dict.
    """
    pytest.importorskip(
        "aiconfigurator",
        reason="router AIC test requires aiconfigurator",
    )
    return dict(ROUTER_AIC_CONFIG)
def get_unique_ports( def get_unique_ports(
...@@ -1070,6 +1084,47 @@ def test_router_decisions( ...@@ -1070,6 +1084,47 @@ def test_router_decisions(
) )
@pytest.mark.timeout(300)
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
def test_router_decisions_router_aic(
    request,
    runtime_services_dynamic_ports,
    predownload_tokenizers,
    request_plane,
):
    """Validate agg KV-router decisions with router-side AIC enabled on the NATS Core path.

    Starts two mocker workers (dp_size 4, block size 8, non-durable KV events)
    and runs the shared router-decision checks with the AIC config forwarded.
    The test is skipped when ``aiconfigurator`` cannot be imported.
    """
    logger.info("Starting agg router decisions test with router-side AIC enabled")
    aic_config = _require_router_aic()

    mocker_pool = MockerProcess(
        request,
        mocker_args={
            "speedup_ratio": SPEEDUP_RATIO,
            "block_size": 8,
            "dp_size": 4,
            "durable_kv_events": False,
        },
        num_mockers=2,
        request_plane=request_plane,
        model_name=MODEL_NAME,
    )
    with mocker_pool as mockers:
        runtime = get_runtime(request_plane=request_plane)
        endpoint_path = f"{mockers.namespace}.mocker.generate"
        endpoint = runtime.endpoint(endpoint_path)

        _test_router_decisions(
            mockers,
            endpoint,
            MODEL_NAME,
            request,
            test_dp_rank=True,
            use_kv_events=True,
            durable_kv_events=False,
            router_aic_config=aic_config,
        )
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"]) @pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"enable_disagg_bootstrap", [False, True], ids=["no_bootstrap", "with_bootstrap"] "enable_disagg_bootstrap", [False, True], ids=["no_bootstrap", "with_bootstrap"]
...@@ -1194,6 +1249,60 @@ def test_router_decisions_disagg( ...@@ -1194,6 +1249,60 @@ def test_router_decisions_disagg(
) )
@pytest.mark.timeout(180)
def test_router_decisions_disagg_router_aic(
    request,
    runtime_services_dynamic_ports,
    predownload_tokenizers,
):
    """Validate disagg KV-router decisions with router-side AIC enabled on the default startup path.

    Four prefill mockers (bootstrap disabled) and four decode mockers share one
    generated namespace on the NATS request plane; the disagg router-decision
    checks then run with the AIC config forwarded. The test is skipped when
    ``aiconfigurator`` cannot be imported.
    """
    logger.info("Starting disaggregated router prefix reuse test with router-side AIC")
    aic_config = _require_router_aic()

    # Settings common to the prefill and decode worker groups.
    worker_kwargs = {
        "namespace": f"test-namespace-{generate_random_suffix()}",
        "mocker_args": {
            "speedup_ratio": SPEEDUP_RATIO,
            "block_size": BLOCK_SIZE,
        },
        "num_mockers": 4,
        "request_plane": "nats",
    }

    with DisaggMockerProcess(
        request,
        worker_type="prefill",
        enable_bootstrap=False,
        **worker_kwargs,
    ) as prefill_workers:
        logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")

        with DisaggMockerProcess(
            request,
            worker_type="decode",
            **worker_kwargs,
        ) as decode_workers:
            logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")

            port_list = get_unique_ports(
                request, num_ports=1, registration_order="prefill_first"
            )
            frontend_port = port_list[0]

            _test_router_decisions_disagg(
                prefill_workers=prefill_workers,
                decode_workers=decode_workers,
                block_size=BLOCK_SIZE,
                request=request,
                frontend_port=frontend_port,
                test_payload=TEST_PAYLOAD,
                request_plane="nats",
                router_aic_config=aic_config,
            )
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"durable_kv_events", [False], ids=["nondurable"], indirect=True "durable_kv_events", [False], ids=["nondurable"], indirect=True
......
...@@ -14,6 +14,7 @@ import pytest ...@@ -14,6 +14,7 @@ import pytest
from tests.router.e2e_harness import ( from tests.router.e2e_harness import (
ManagedEngineProcessMixin, ManagedEngineProcessMixin,
run_basic_router_test, run_basic_router_test,
run_disagg_router_decisions_test,
run_indexers_sync_test, run_indexers_sync_test,
run_router_decisions_test, run_router_decisions_test,
) )
...@@ -65,6 +66,9 @@ class SGLangProcess(ManagedEngineProcessMixin): ...@@ -65,6 +66,9 @@ class SGLangProcess(ManagedEngineProcessMixin):
request_plane: str = "tcp", request_plane: str = "tcp",
store_backend: str = "etcd", store_backend: str = "etcd",
durable_kv_events: bool = False, durable_kv_events: bool = False,
namespace: Optional[str] = None,
gpu_start_index: int = 0,
disaggregation_mode: Optional[str] = None,
): ):
"""Initialize SGLang workers with dynamo integration. """Initialize SGLang workers with dynamo integration.
...@@ -85,8 +89,10 @@ class SGLangProcess(ManagedEngineProcessMixin): ...@@ -85,8 +89,10 @@ class SGLangProcess(ManagedEngineProcessMixin):
""" """
# Generate unique namespace for isolation # Generate unique namespace for isolation
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}" self.namespace = namespace or f"test-namespace-{namespace_suffix}"
self.component_name = "backend" self.component_name = (
"prefill" if disaggregation_mode == "prefill" else "backend"
)
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_workers self.num_workers = num_workers
self.data_parallel_size = data_parallel_size self.data_parallel_size = data_parallel_size
...@@ -116,10 +122,10 @@ class SGLangProcess(ManagedEngineProcessMixin): ...@@ -116,10 +122,10 @@ class SGLangProcess(ManagedEngineProcessMixin):
# Calculate GPU device for this process # Calculate GPU device for this process
if single_gpu: if single_gpu:
# Force all processes to GPU 0 (for single-GPU testing) # Force all processes to GPU 0 (for single-GPU testing)
gpu_device = "0" gpu_device = str(gpu_start_index)
elif data_parallel_size is not None: elif data_parallel_size is not None:
# Worker sees dp_rank GPUs (each DP rank gets its own GPU) # Worker sees dp_rank GPUs (each DP rank gets its own GPU)
worker_start_gpu = worker_idx * data_parallel_size worker_start_gpu = gpu_start_index + worker_idx * data_parallel_size
gpu_device = ",".join( gpu_device = ",".join(
str(i) str(i)
for i in range( for i in range(
...@@ -128,7 +134,7 @@ class SGLangProcess(ManagedEngineProcessMixin): ...@@ -128,7 +134,7 @@ class SGLangProcess(ManagedEngineProcessMixin):
) )
else: else:
# No DP; worker sees one GPU # No DP; worker sees one GPU
gpu_device = str(worker_idx) gpu_device = str(gpu_start_index + worker_idx)
command = [ command = [
"python3", "python3",
...@@ -152,6 +158,10 @@ class SGLangProcess(ManagedEngineProcessMixin): ...@@ -152,6 +158,10 @@ class SGLangProcess(ManagedEngineProcessMixin):
if context_length is not None: if context_length is not None:
command.extend(["--context-length", str(context_length)]) command.extend(["--context-length", str(context_length)])
if disaggregation_mode is not None:
command.extend(["--disaggregation-mode", disaggregation_mode])
command.extend(["--disaggregation-transfer-backend", "nixl"])
if data_parallel_size is not None: if data_parallel_size is not None:
# Add DP configuration # Add DP configuration
command.extend( command.extend(
...@@ -308,6 +318,40 @@ def test_router_decisions_sglang_dp( ...@@ -308,6 +318,40 @@ def test_router_decisions_sglang_dp(
) )
@pytest.mark.gpu_2
@pytest.mark.nightly
@pytest.mark.parametrize("request_plane", ["nats"], indirect=True)
@pytest.mark.timeout(600)
def test_router_decisions_sglang_disagg(
    request,
    runtime_services_dynamic_ports,
    predownload_models,
    set_ucx_tls_no_mm,
    request_plane,
):
    """Run the disagg router-decision checks against SGLang workers.

    Two prefill workers are started with GPU start index 0 and one decode
    worker with GPU start index 1, each in single-GPU mode.
    """
    prefill_kwargs = {
        "single_gpu": True,
        "gpu_start_index": 0,
        "disaggregation_mode": "prefill",
    }
    decode_kwargs = {
        "single_gpu": True,
        "gpu_start_index": 1,
        "disaggregation_mode": "decode",
    }
    run_disagg_router_decisions_test(
        engine_process_cls=SGLangProcess,
        engine_args_name="sglang_args",
        engine_args=SGLANG_ARGS,
        request=request,
        request_plane=request_plane,
        model_name=MODEL_NAME,
        block_size=PAGE_SIZE,
        num_prefill_workers=2,
        num_decode_workers=1,
        prefill_process_kwargs=prefill_kwargs,
        decode_process_kwargs=decode_kwargs,
    )
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -14,6 +14,7 @@ import pytest ...@@ -14,6 +14,7 @@ import pytest
from tests.router.e2e_harness import ( from tests.router.e2e_harness import (
ManagedEngineProcessMixin, ManagedEngineProcessMixin,
run_basic_router_test, run_basic_router_test,
run_disagg_router_decisions_test,
run_indexers_sync_test, run_indexers_sync_test,
run_router_decisions_test, run_router_decisions_test,
) )
...@@ -63,6 +64,9 @@ class TRTLLMProcess(ManagedEngineProcessMixin): ...@@ -63,6 +64,9 @@ class TRTLLMProcess(ManagedEngineProcessMixin):
request_plane: str = "tcp", request_plane: str = "tcp",
store_backend: str = "etcd", store_backend: str = "etcd",
durable_kv_events: bool = False, durable_kv_events: bool = False,
namespace: Optional[str] = None,
gpu_start_index: int = 0,
disaggregation_mode: Optional[str] = None,
): ):
"""Initialize TRT-LLM workers with dynamo integration. """Initialize TRT-LLM workers with dynamo integration.
...@@ -91,8 +95,10 @@ class TRTLLMProcess(ManagedEngineProcessMixin): ...@@ -91,8 +95,10 @@ class TRTLLMProcess(ManagedEngineProcessMixin):
""" """
# Generate unique namespace for isolation # Generate unique namespace for isolation
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}" self.namespace = namespace or f"test-namespace-{namespace_suffix}"
self.component_name = "tensorrt_llm" self.component_name = (
"prefill" if disaggregation_mode == "prefill" else "tensorrt_llm"
)
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_workers self.num_workers = num_workers
self.worker_processes = [] self.worker_processes = []
...@@ -118,14 +124,16 @@ class TRTLLMProcess(ManagedEngineProcessMixin): ...@@ -118,14 +124,16 @@ class TRTLLMProcess(ManagedEngineProcessMixin):
# Calculate GPU device for this process # Calculate GPU device for this process
if single_gpu: if single_gpu:
# Force all processes to GPU 0 (for single-GPU testing) # Force all processes to GPU 0 (for single-GPU testing)
gpu_device = "0" gpu_device = str(gpu_start_index)
elif enable_attention_dp and tensor_parallel_size: elif enable_attention_dp and tensor_parallel_size:
# For attention DP, TRT-LLM spawns tensor_parallel_size internal MPI workers. # For attention DP, TRT-LLM spawns tensor_parallel_size internal MPI workers.
# So one process = two attention DP ranks = visibility in to both GPUs. # So one process = two attention DP ranks = visibility in to both GPUs.
gpu_device = ",".join(str(i) for i in range(tensor_parallel_size)) gpu_device = ",".join(
str(gpu_start_index + i) for i in range(tensor_parallel_size)
)
else: else:
# Each worker sees one GPU # Each worker sees one GPU
gpu_device = str(worker_idx) gpu_device = str(gpu_start_index + worker_idx)
# Single-node TRT-LLM workers use python3 -m dynamo.trtllm directly # Single-node TRT-LLM workers use python3 -m dynamo.trtllm directly
# (trtllm-llmapi-launch is only needed for multi-node MPI deployments) # (trtllm-llmapi-launch is only needed for multi-node MPI deployments)
...@@ -141,6 +149,9 @@ class TRTLLMProcess(ManagedEngineProcessMixin): ...@@ -141,6 +149,9 @@ class TRTLLMProcess(ManagedEngineProcessMixin):
"--publish-events-and-metrics", "--publish-events-and-metrics",
] ]
if disaggregation_mode is not None:
command.extend(["--disaggregation-mode", disaggregation_mode])
# Limit VRAM allocation (required for multi-worker on same GPU) # Limit VRAM allocation (required for multi-worker on same GPU)
if free_gpu_memory_fraction is not None: if free_gpu_memory_fraction is not None:
command.extend( command.extend(
...@@ -291,6 +302,40 @@ def test_router_decisions_trtllm_multiple_workers( ...@@ -291,6 +302,40 @@ def test_router_decisions_trtllm_multiple_workers(
) )
@pytest.mark.gpu_2
@pytest.mark.nightly
@pytest.mark.parametrize("request_plane", ["nats"], indirect=True)
@pytest.mark.timeout(600)
def test_router_decisions_trtllm_disagg(
    request,
    runtime_services_dynamic_ports,
    predownload_models,
    set_ucx_tls_no_mm,
    request_plane,
):
    """Run the disagg router-decision checks against TRT-LLM workers.

    Two prefill workers are started with GPU start index 0 and one decode
    worker with GPU start index 1, each in single-GPU mode.
    """
    prefill_kwargs = {
        "single_gpu": True,
        "gpu_start_index": 0,
        "disaggregation_mode": "prefill",
    }
    decode_kwargs = {
        "single_gpu": True,
        "gpu_start_index": 1,
        "disaggregation_mode": "decode",
    }
    run_disagg_router_decisions_test(
        engine_process_cls=TRTLLMProcess,
        engine_args_name="trtllm_args",
        engine_args=TRTLLM_ARGS,
        request=request,
        request_plane=request_plane,
        model_name=MODEL_NAME,
        block_size=TRTLLM_BLOCK_SIZE,
        num_prefill_workers=2,
        num_decode_workers=1,
        prefill_process_kwargs=prefill_kwargs,
        decode_process_kwargs=decode_kwargs,
    )
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~45s/test), rounded up
......
...@@ -18,6 +18,7 @@ import pytest ...@@ -18,6 +18,7 @@ import pytest
from tests.router.e2e_harness import ( from tests.router.e2e_harness import (
ManagedEngineProcessMixin, ManagedEngineProcessMixin,
run_basic_router_test, run_basic_router_test,
run_disagg_router_decisions_test,
run_indexers_sync_test, run_indexers_sync_test,
run_router_decisions_test, run_router_decisions_test,
) )
...@@ -81,6 +82,9 @@ class VLLMProcess(ManagedEngineProcessMixin): ...@@ -81,6 +82,9 @@ class VLLMProcess(ManagedEngineProcessMixin):
request_plane: str = "tcp", request_plane: str = "tcp",
store_backend: str = "etcd", store_backend: str = "etcd",
durable_kv_events: bool = False, durable_kv_events: bool = False,
namespace: Optional[str] = None,
gpu_start_index: int = 0,
disaggregation_mode: Optional[str] = None,
standalone_indexer: bool = False, standalone_indexer: bool = False,
zmq_replay: bool = False, zmq_replay: bool = False,
): ):
...@@ -103,8 +107,10 @@ class VLLMProcess(ManagedEngineProcessMixin): ...@@ -103,8 +107,10 @@ class VLLMProcess(ManagedEngineProcessMixin):
""" """
# Generate unique namespace for isolation # Generate unique namespace for isolation
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}" self.namespace = namespace or f"test-namespace-{namespace_suffix}"
self.component_name = "backend" self.component_name = (
"prefill" if disaggregation_mode == "prefill" else "backend"
)
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_workers self.num_workers = num_workers
self.data_parallel_size = data_parallel_size self.data_parallel_size = data_parallel_size
...@@ -170,10 +176,10 @@ class VLLMProcess(ManagedEngineProcessMixin): ...@@ -170,10 +176,10 @@ class VLLMProcess(ManagedEngineProcessMixin):
# Calculate GPU device for this process # Calculate GPU device for this process
if single_gpu: if single_gpu:
# Force all processes to GPU 0 (for single-GPU testing) # Force all processes to GPU 0 (for single-GPU testing)
gpu_device = "0" gpu_device = str(gpu_start_index)
elif data_parallel_size is not None: elif data_parallel_size is not None:
# Worker sees dp_rank GPUs (each DP rank gets its own GPU) # Worker sees dp_rank GPUs (each DP rank gets its own GPU)
worker_start_gpu = worker_idx * data_parallel_size worker_start_gpu = gpu_start_index + worker_idx * data_parallel_size
gpu_device = ",".join( gpu_device = ",".join(
str(i) str(i)
for i in range( for i in range(
...@@ -182,13 +188,22 @@ class VLLMProcess(ManagedEngineProcessMixin): ...@@ -182,13 +188,22 @@ class VLLMProcess(ManagedEngineProcessMixin):
) )
else: else:
# No DP; worker sees one GPU # No DP; worker sees one GPU
gpu_device = str(worker_idx) gpu_device = str(gpu_start_index + worker_idx)
command = ["python3", "-m", "dynamo.vllm", "--model", model] command = ["python3", "-m", "dynamo.vllm", "--model", model]
if "block_size" in vllm_args: if "block_size" in vllm_args:
command.extend(["--block-size", str(vllm_args["block_size"])]) command.extend(["--block-size", str(vllm_args["block_size"])])
if disaggregation_mode is not None:
command.extend(["--disaggregation-mode", disaggregation_mode])
command.extend(
[
"--kv-transfer-config",
'{"kv_connector":"NixlConnector","kv_role":"kv_both"}',
]
)
# Disable CUDA graphs for faster startup & lower memory # Disable CUDA graphs for faster startup & lower memory
if enforce_eager: if enforce_eager:
command.append("--enforce-eager") command.append("--enforce-eager")
...@@ -573,6 +588,40 @@ def test_router_decisions_vllm_dp( ...@@ -573,6 +588,40 @@ def test_router_decisions_vllm_dp(
) )
@pytest.mark.gpu_2
@pytest.mark.nightly
@pytest.mark.timeout(600)
@pytest.mark.parametrize("request_plane", ["nats"], indirect=True)
def test_router_decisions_vllm_disagg(
    request,
    runtime_services_dynamic_ports,
    predownload_models,
    set_ucx_tls_no_mm,
    request_plane,
):
    """Run the disagg router-decision checks against vLLM workers.

    Two prefill workers are started with GPU start index 0 and one decode
    worker with GPU start index 1, each in single-GPU mode.
    """
    prefill_kwargs = {
        "single_gpu": True,
        "gpu_start_index": 0,
        "disaggregation_mode": "prefill",
    }
    decode_kwargs = {
        "single_gpu": True,
        "gpu_start_index": 1,
        "disaggregation_mode": "decode",
    }
    run_disagg_router_decisions_test(
        engine_process_cls=VLLMProcess,
        engine_args_name="vllm_args",
        engine_args=VLLM_ARGS,
        request=request,
        request_plane=request_plane,
        model_name=MODEL_NAME,
        block_size=BLOCK_SIZE,
        num_prefill_workers=2,
        num_decode_workers=1,
        prefill_process_kwargs=prefill_kwargs,
        decode_process_kwargs=decode_kwargs,
    )
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~43s/test), rounded up
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment