Unverified Commit 0b33c1df authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: sglang disagg routing fixes and optimizations [DYN-1692] (#5106)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarIshan Dhanani <ishandhanani@gmail.com>
Co-authored-by: default avatarSean SH Choi <sechoi@nvidia.com>
Co-authored-by: default avatarishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent e49834c9
...@@ -84,7 +84,6 @@ def get_aiperf_cmd( ...@@ -84,7 +84,6 @@ def get_aiperf_cmd(
str(num_prefix_prompts), str(num_prefix_prompts),
"--artifact-dir", "--artifact-dir",
artifact_dir, artifact_dir,
"-v",
"-H", "-H",
"Authorization: Bearer NOT USED", "Authorization: Bearer NOT USED",
"-H", "-H",
......
...@@ -55,7 +55,6 @@ def get_aiperf_cmd_for_trace( ...@@ -55,7 +55,6 @@ def get_aiperf_cmd_for_trace(
str(seed), str(seed),
"--artifact-dir", "--artifact-dir",
artifact_dir, artifact_dir,
"-v",
"-H", "-H",
"Authorization: Bearer NOT USED", "Authorization: Bearer NOT USED",
"-H", "-H",
......
...@@ -161,8 +161,8 @@ def parse_args(): ...@@ -161,8 +161,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--router-max-tree-size", "--router-max-tree-size",
type=int, type=int,
default=int(os.environ.get("DYN_ROUTER_MAX_TREE_SIZE", str(2**10))), default=int(os.environ.get("DYN_ROUTER_MAX_TREE_SIZE", str(2**20))),
help="KV Router: Maximum tree size before pruning when KV events are disabled. Only used when --no-kv-events is set. Can be set via DYN_ROUTER_MAX_TREE_SIZE env var (default: 1024).", help="KV Router: Maximum tree size before pruning when KV events are disabled. Only used when --no-kv-events is set. Can be set via DYN_ROUTER_MAX_TREE_SIZE env var (default: 1048576, which is 2^20).",
) )
parser.add_argument( parser.add_argument(
"--router-prune-target-ratio", "--router-prune-target-ratio",
......
...@@ -237,8 +237,8 @@ def parse_args(): ...@@ -237,8 +237,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--router-max-tree-size", "--router-max-tree-size",
type=int, type=int,
default=2**10, default=2**20,
help="KV Router: Maximum tree size before pruning. Only used when --no-kv-events is set. When the indexer tree exceeds this size, pruning is triggered (default: 1024)", help="KV Router: Maximum tree size before pruning. Only used when --no-kv-events is set. When the indexer tree exceeds this size, pruning is triggered (default: 1048576, which is 2^20)",
) )
parser.add_argument( parser.add_argument(
......
...@@ -137,17 +137,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -137,17 +137,6 @@ async def init(runtime: DistributedRuntime, config: Config):
"Registered engine routes: /engine/start_profile, /engine/stop_profile" "Registered engine routes: /engine/start_profile, /engine/stop_profile"
) )
# Create prefill client for disaggregated decode mode (fallback when --router-mode kv is not used)
prefill_client = None
if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client for disaggregated decode worker")
prefill_client = (
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
.endpoint("generate")
.client()
)
# publisher instantiates the metrics and kv event publishers # publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
...@@ -160,7 +149,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -160,7 +149,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = DecodeWorkerHandler(component, engine, config, publisher, prefill_client) handler = DecodeWorkerHandler(component, engine, config, publisher)
print(f"Config: {config}") print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload( health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer engine, use_text_input=dynamo_args.use_sglang_tokenizer
......
...@@ -15,7 +15,7 @@ import sglang as sgl ...@@ -15,7 +15,7 @@ import sglang as sgl
from sglang.srt.tracing import trace as sglang_trace from sglang.srt.tracing import trace as sglang_trace
from sglang.srt.utils import get_local_ip_auto from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Client, Component, Context from dynamo._core import Component, Context
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -30,7 +30,6 @@ class BaseWorkerHandler(ABC): ...@@ -30,7 +30,6 @@ class BaseWorkerHandler(ABC):
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: Optional[DynamoSglangPublisher] = None, publisher: Optional[DynamoSglangPublisher] = None,
prefill_client: Optional[Client] = None,
) -> None: ) -> None:
"""Initialize base worker handler. """Initialize base worker handler.
...@@ -39,7 +38,6 @@ class BaseWorkerHandler(ABC): ...@@ -39,7 +38,6 @@ class BaseWorkerHandler(ABC):
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker. publisher: Optional metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
""" """
self.component = component self.component = component
self.engine = engine self.engine = engine
...@@ -50,7 +48,6 @@ class BaseWorkerHandler(ABC): ...@@ -50,7 +48,6 @@ class BaseWorkerHandler(ABC):
else: else:
self.metrics_publisher = None self.metrics_publisher = None
self.kv_publisher = None self.kv_publisher = None
self.prefill_client = prefill_client
self.serving_mode = config.serving_mode self.serving_mode = config.serving_mode
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
self.enable_trace = config.server_args.enable_trace self.enable_trace = config.server_args.enable_trace
......
...@@ -4,13 +4,12 @@ ...@@ -4,13 +4,12 @@
import asyncio import asyncio
import logging import logging
import time import time
from typing import Any, AsyncGenerator, Dict, Optional from typing import Any, AsyncGenerator, Dict
import sglang as sgl import sglang as sgl
from dynamo._core import Client, Component, Context from dynamo._core import Component, Context
from dynamo.sglang.args import Config, DisaggregationMode from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import DisaggPreprocessedRequest
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...@@ -24,7 +23,6 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -24,7 +23,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: DynamoSglangPublisher, publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None,
) -> None: ) -> None:
"""Initialize decode worker handler. """Initialize decode worker handler.
...@@ -33,14 +31,12 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -33,14 +31,12 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker. publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
""" """
super().__init__( super().__init__(
component, component,
engine, engine,
config, config,
publisher, publisher,
prefill_client,
) )
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
logging.info( logging.info(
...@@ -108,52 +104,16 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -108,52 +104,16 @@ class DecodeWorkerHandler(BaseWorkerHandler):
input_param = self._get_input_param(request) input_param = self._get_input_param(request)
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
# Check if bootstrap_info is pre-computed in the request (from frontend with --router-mode kv) # Check if bootstrap_info is pre-computed in the request (from frontend)
bootstrap_info = request.get("bootstrap_info") bootstrap_info = request.get("bootstrap_info")
if not bootstrap_info: if not bootstrap_info:
# Fallback: fetch bootstrap_info from prefill worker via round-robin routing
if self.prefill_client is None:
raise RuntimeError( raise RuntimeError(
"bootstrap_info is required for disaggregated decode but was not provided, " "bootstrap_info is required for disaggregated decode but was not provided"
"and no prefill_client is available for fallback."
) )
logging.debug( logging.debug(
"No bootstrap_info in request, fetching from prefill worker" f"Using bootstrap_info: "
)
prefill_stream = await self.prefill_client.generate(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
context=context,
)
prefill_response = None
async for info in prefill_stream:
prefill_response = info.data()
break
if not prefill_response:
raise RuntimeError("No response received from prefill worker")
# Extract bootstrap_info from disaggregated_params (PrefillWorkerHandler format)
bootstrap_info = prefill_response.get("disaggregated_params")
if not bootstrap_info:
raise RuntimeError(
"No bootstrap info (disaggregated_params) received from prefill worker"
)
logging.debug(
f"Received bootstrap_info from prefill worker: "
f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}"
)
else:
logging.debug(
f"Using pre-computed bootstrap_info: "
f"host={bootstrap_info['bootstrap_host']}, " f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, " f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}" f"room={bootstrap_info['bootstrap_room']}"
......
...@@ -84,11 +84,12 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -84,11 +84,12 @@ class PrefillWorkerHandler(BaseWorkerHandler):
k: v for k, v in sampling_params.items() if v is not None k: v for k, v in sampling_params.items() if v is not None
} }
# Use provided bootstrap_room if available, otherwise generate one # Use provided bootstrap_room from bootstrap_info if available, otherwise generate one
bootstrap_room = None bootstrap_room = None
extra_args = inner_request.get("extra_args", {}) bootstrap_info_from_req = inner_request.get("bootstrap_info")
if isinstance(extra_args, dict): if isinstance(bootstrap_info_from_req, dict):
bootstrap_room = extra_args.get("bootstrap_room") bootstrap_room = bootstrap_info_from_req.get("bootstrap_room")
if bootstrap_room is not None:
logging.debug(f"Using router-provided bootstrap_room: {bootstrap_room}") logging.debug(f"Using router-provided bootstrap_room: {bootstrap_room}")
if bootstrap_room is None: if bootstrap_room is None:
...@@ -130,6 +131,8 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -130,6 +131,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self._consume_tasks.add(task) self._consume_tasks.add(task)
task.add_done_callback(self._consume_tasks.discard) task.add_done_callback(self._consume_tasks.discard)
await task
async def _consume_results( async def _consume_results(
self, results: AsyncGenerator[Any, None], context: Context self, results: AsyncGenerator[Any, None], context: Context
) -> None: ) -> None:
......
...@@ -52,7 +52,7 @@ The main KV-aware routing arguments: ...@@ -52,7 +52,7 @@ The main KV-aware routing arguments:
> If you run with `DYN_REQUEST_PLANE=tcp` (or `http`) and KV events enabled (default), you must also configure NATS, e.g. `NATS_SERVER=nats://...`. > If you run with `DYN_REQUEST_PLANE=tcp` (or `http`) and KV events enabled (default), you must also configure NATS, e.g. `NATS_SERVER=nats://...`.
> Only `--no-kv-events` removes the NATS requirement. > Only `--no-kv-events` removes the NATS requirement.
> >
> When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is set, no KvIndexer will be launched to drain and process KV events. It's recommended to disable your backend workers from relaying events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases. > When `--kv-overlap-score-weight` is set to 0, no KvIndexer is created and prefix matching is disabled (pure load balancing). When `--no-kv-events` is set, a KvIndexer is still created but no event subscriber is launched to consume KV events from workers. Instead, the router predicts cache state based on its own routing decisions with TTL-based expiration and pruning. In both cases, it's recommended to disable your backend workers from publishing events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases.
> >
> The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored. > The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored.
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
# Setup cleanup trap # Setup cleanup trap
cleanup() { cleanup() {
echo "Cleaning up background processes..." echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 2>/dev/null || true kill $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 $DECODE_PID2 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 2>/dev/null || true wait $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 $DECODE_PID2 2>/dev/null || true
echo "Cleanup complete." echo "Cleanup complete."
} }
trap cleanup EXIT INT TERM trap cleanup EXIT INT TERM
...@@ -26,7 +26,7 @@ while [[ $# -gt 0 ]]; do ...@@ -26,7 +26,7 @@ while [[ $# -gt 0 ]]; do
echo " -h, --help Show this help message" echo " -h, --help Show this help message"
echo "" echo ""
echo "Note: System metrics are enabled by default on ports:" echo "Note: System metrics are enabled by default on ports:"
echo " 8082-8083 (prefill workers), 8084-8085 (decode workers)" echo " 8081-8082 (prefill workers), 8083-8084 (decode workers)"
exit 0 exit 0
;; ;;
*) *)
...@@ -48,6 +48,7 @@ fi ...@@ -48,6 +48,7 @@ fi
# Start frontend with KV routing # Start frontend with KV routing
# The frontend will automatically detect prefill workers and activate an internal prefill router # The frontend will automatically detect prefill workers and activate an internal prefill router
# No standalone prefill router needed - the frontend handles prefill routing internally
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000) # dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
OTEL_SERVICE_NAME=dynamo-frontend \ OTEL_SERVICE_NAME=dynamo-frontend \
python3 -m dynamo.frontend \ python3 -m dynamo.frontend \
...@@ -55,19 +56,8 @@ python3 -m dynamo.frontend \ ...@@ -55,19 +56,8 @@ python3 -m dynamo.frontend \
--router-reset-states & --router-reset-states &
DYNAMO_PID=$! DYNAMO_PID=$!
# run prefill router
# Use numeric DYN_SYSTEM_PORT{N} env vars so launchers/test harnesses can set
# ports without encoding role names (prefill/decode) in the env var.
OTEL_SERVICE_NAME=dynamo-router-prefill DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
python3 -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size 64 \
--router-reset-states \
--no-track-active-blocks &
PREFILL_ROUTER_PID=$!
# run prefill worker # run prefill worker
OTEL_SERVICE_NAME=dynamo-worker-prefill-1 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT2:-8082} \ OTEL_SERVICE_NAME=dynamo-worker-prefill-1 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
python3 -m dynamo.sglang \ python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
...@@ -80,10 +70,10 @@ python3 -m dynamo.sglang \ ...@@ -80,10 +70,10 @@ python3 -m dynamo.sglang \
--disaggregation-transfer-backend nixl \ --disaggregation-transfer-backend nixl \
--enable-metrics \ --enable-metrics \
"${TRACE_ARGS[@]}" & "${TRACE_ARGS[@]}" &
PREFILL_PID=$! PREFILL_PID1=$!
# run prefill worker # run prefill worker
OTEL_SERVICE_NAME=dynamo-worker-prefill-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT3:-8083} \ OTEL_SERVICE_NAME=dynamo-worker-prefill-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT2:-8082} \
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
...@@ -96,10 +86,10 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \ ...@@ -96,10 +86,10 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--disaggregation-transfer-backend nixl \ --disaggregation-transfer-backend nixl \
--enable-metrics \ --enable-metrics \
"${TRACE_ARGS[@]}" & "${TRACE_ARGS[@]}" &
PREFILL_PID=$! PREFILL_PID2=$!
# run decode worker # run decode worker
OTEL_SERVICE_NAME=dynamo-worker-decode-1 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT4:-8084} \ OTEL_SERVICE_NAME=dynamo-worker-decode-1 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT3:-8083} \
CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \ CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
...@@ -112,10 +102,10 @@ CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \ ...@@ -112,10 +102,10 @@ CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \
--disaggregation-transfer-backend nixl \ --disaggregation-transfer-backend nixl \
--enable-metrics \ --enable-metrics \
"${TRACE_ARGS[@]}" & "${TRACE_ARGS[@]}" &
PREFILL_PID=$! DECODE_PID1=$!
# run decode worker # run decode worker
OTEL_SERVICE_NAME=dynamo-worker-decode-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT5:-8085} \ OTEL_SERVICE_NAME=dynamo-worker-decode-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT4:-8084} \
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
...@@ -127,4 +117,8 @@ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \ ...@@ -127,4 +117,8 @@ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5559"}' \ --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5559"}' \
--disaggregation-transfer-backend nixl \ --disaggregation-transfer-backend nixl \
--enable-metrics \ --enable-metrics \
"${TRACE_ARGS[@]}" "${TRACE_ARGS[@]}" &
DECODE_PID2=$!
# Wait for any worker to exit (keeps script running)
wait
...@@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicU32, Ordering}; ...@@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicU32, Ordering};
use dynamo_llm::{ use dynamo_llm::{
discovery::{KvWorkerMonitor, ModelWatcher}, discovery::{KvWorkerMonitor, ModelWatcher},
kv_router::{indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher}, kv_router::{protocols::*, publisher::KvEventPublisher},
}; };
use dynamo_runtime::{DistributedRuntime, Worker}; use dynamo_runtime::{DistributedRuntime, Worker};
static WK: OnceCell<Worker> = OnceCell::new(); static WK: OnceCell<Worker> = OnceCell::new();
...@@ -1475,6 +1475,7 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -1475,6 +1475,7 @@ pub async fn create_worker_selection_pipeline_chat(
>( >(
&card_with_local_files, &card_with_local_files,
&client, &client,
model_manager.clone(),
router_mode, router_mode,
worker_monitor, worker_monitor,
chooser, chooser,
......
...@@ -1587,6 +1587,7 @@ dependencies = [ ...@@ -1587,6 +1587,7 @@ dependencies = [
"tokio-util", "tokio-util",
"tracing", "tracing",
"url", "url",
"utoipa",
"uuid", "uuid",
] ]
...@@ -7792,6 +7793,8 @@ dependencies = [ ...@@ -7792,6 +7793,8 @@ dependencies = [
"quote", "quote",
"regex", "regex",
"syn 2.0.110", "syn 2.0.110",
"url",
"uuid",
] ]
[[package]] [[package]]
......
...@@ -50,7 +50,7 @@ impl KvRouterConfig { ...@@ -50,7 +50,7 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1024, router_prune_target_ratio=0.8))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
......
...@@ -11,11 +11,11 @@ use tokio_stream::StreamExt; ...@@ -11,11 +11,11 @@ use tokio_stream::StreamExt;
use super::*; use super::*;
use crate::Component; use crate::Component;
use llm_rs::kv_router::indexer::KvIndexerInterface; use llm_rs::kv_router::indexer::KvIndexerInterface;
use llm_rs::kv_router::indexer::compute_block_hash_for_seq;
use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics; use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics;
use llm_rs::kv_router::protocols::KvStats as RsKvStats; use llm_rs::kv_router::protocols::KvStats as RsKvStats;
use llm_rs::kv_router::protocols::SpecDecodeStats as RsSpecDecodeStats; use llm_rs::kv_router::protocols::SpecDecodeStats as RsSpecDecodeStats;
use llm_rs::kv_router::protocols::WorkerStats as RsWorkerStats; use llm_rs::kv_router::protocols::WorkerStats as RsWorkerStats;
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
use rs::pipeline::{AsyncEngine, SingleIn}; use rs::pipeline::{AsyncEngine, SingleIn};
use rs::traits::events::EventSubscriber; use rs::traits::events::EventSubscriber;
use tracing; use tracing;
...@@ -782,7 +782,7 @@ pub(crate) struct ApproxKvIndexer { ...@@ -782,7 +782,7 @@ pub(crate) struct ApproxKvIndexer {
#[pymethods] #[pymethods]
impl ApproxKvIndexer { impl ApproxKvIndexer {
#[new] #[new]
#[pyo3(signature = (component, kv_block_size, router_ttl_secs=120.0, router_max_tree_size=1024, router_prune_target_ratio=0.8))] #[pyo3(signature = (component, kv_block_size, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
fn new( fn new(
component: Component, component: Component,
kv_block_size: usize, kv_block_size: usize,
...@@ -851,10 +851,12 @@ impl ApproxKvIndexer { ...@@ -851,10 +851,12 @@ impl ApproxKvIndexer {
dp_rank: DpRank, dp_rank: DpRank,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone(); let indexer = self.inner.clone();
let block_size = self.inner.block_size();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker = llm_rs::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank); let worker = llm_rs::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);
let mut tokens_with_hashes = TokensWithHashes::new(tokens, block_size);
indexer indexer
.process_routing_decision_for_request(tokens.as_slice(), worker) .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(()) Ok(())
......
...@@ -684,7 +684,7 @@ class ApproxKvIndexer: ...@@ -684,7 +684,7 @@ class ApproxKvIndexer:
component: Component, component: Component,
kv_block_size: int, kv_block_size: int,
router_ttl_secs: float = 120.0, router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1024, router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8, router_prune_target_ratio: float = 0.8,
) -> None: ) -> None:
""" """
...@@ -694,7 +694,7 @@ class ApproxKvIndexer: ...@@ -694,7 +694,7 @@ class ApproxKvIndexer:
component: The component to associate with this indexer component: The component to associate with this indexer
kv_block_size: The KV cache block size kv_block_size: The KV cache block size
router_ttl_secs: TTL for blocks in seconds (default: 120.0) router_ttl_secs: TTL for blocks in seconds (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1024) router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8) router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
""" """
... ...
...@@ -1091,7 +1091,7 @@ class KvRouterConfig: ...@@ -1091,7 +1091,7 @@ class KvRouterConfig:
router_snapshot_threshold: Optional[int] = 1000000, router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False, router_reset_states: bool = False,
router_ttl_secs: float = 120.0, router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1024, router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8, router_prune_target_ratio: float = 0.8,
) -> None: ) -> None:
""" """
...@@ -1106,7 +1106,7 @@ class KvRouterConfig: ...@@ -1106,7 +1106,7 @@ class KvRouterConfig:
router_snapshot_threshold: Number of messages before snapshot (default: 1000000) router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
router_reset_states: Reset router state on startup (default: False) router_reset_states: Reset router state on startup (default: False)
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0) router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1024) router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8) router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
""" """
... ...
......
...@@ -6,6 +6,7 @@ use std::{ ...@@ -6,6 +6,7 @@ use std::{
sync::Arc, sync::Arc,
}; };
use dashmap::{DashMap, mapref::entry::Entry};
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot; use tokio::sync::oneshot;
...@@ -13,13 +14,17 @@ use crate::discovery::KvWorkerMonitor; ...@@ -13,13 +14,17 @@ use crate::discovery::KvWorkerMonitor;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Client, Endpoint, build_transport_type}, component::{Client, Endpoint, build_transport_type},
discovery::DiscoverySpec, discovery::{DiscoveryQuery, DiscoverySpec, watch_and_extract_field},
prelude::DistributedRuntimeProvider, prelude::DistributedRuntimeProvider,
protocols::EndpointId, protocols::EndpointId,
}; };
use crate::{ use crate::{
kv_router::{KvRouter, KvRouterConfig, router_endpoint_id, scheduler::DefaultWorkerSelector}, kv_router::{
KvRouter, KvRouterConfig, protocols::WorkerId, router_endpoint_id,
scheduler::DefaultWorkerSelector,
},
local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig},
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
model_type::ModelType, model_type::ModelType,
types::{ types::{
...@@ -73,6 +78,11 @@ pub struct ModelManager { ...@@ -73,6 +78,11 @@ pub struct ModelManager {
/// Key: model name, Value: cloneable monitor (all fields are Arc). /// Key: model name, Value: cloneable monitor (all fields are Arc).
/// HTTP endpoint can update thresholds via monitor.set_threshold(). /// HTTP endpoint can update thresholds via monitor.set_threshold().
worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>, worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>,
/// Runtime configs per endpoint using DashMap for lock-free access.
/// Outer DashMap: keyed by EndpointId
/// Inner Arc<DashMap>: keyed by WorkerId, shared with KvScheduler
runtime_configs: DashMap<EndpointId, Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>>,
} }
impl Default for ModelManager { impl Default for ModelManager {
...@@ -93,6 +103,7 @@ impl ModelManager { ...@@ -93,6 +103,7 @@ impl ModelManager {
kv_choosers: Mutex::new(HashMap::new()), kv_choosers: Mutex::new(HashMap::new()),
prefill_router_activators: Mutex::new(HashMap::new()), prefill_router_activators: Mutex::new(HashMap::new()),
worker_monitors: RwLock::new(HashMap::new()), worker_monitors: RwLock::new(HashMap::new()),
runtime_configs: DashMap::new(),
} }
} }
...@@ -360,10 +371,14 @@ impl ModelManager { ...@@ -360,10 +371,14 @@ impl ModelManager {
// Use instance_id (hex) as the consumer ID for NATS consumer coordination // Use instance_id (hex) as the consumer ID for NATS consumer coordination
let consumer_id = instance_id.to_string(); let consumer_id = instance_id.to_string();
// Get or create runtime config watcher for this endpoint
let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new( let chooser = KvRouter::new(
endpoint.clone(), endpoint.clone(),
client, client,
workers_with_configs,
kv_cache_block_size, kv_cache_block_size,
Some(selector), Some(selector),
kv_router_config, kv_router_config,
...@@ -602,6 +617,146 @@ impl ModelManager { ...@@ -602,6 +617,146 @@ impl ModelManager {
self.worker_monitors.read().get(model).cloned() self.worker_monitors.read().get(model).cloned()
} }
/// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task to watch DiscoveryQuery::EndpointModels.
/// Returns a shared Arc<DashMap> that KvScheduler can use directly.
pub async fn get_or_create_runtime_config_watcher(
&self,
endpoint: &Endpoint,
) -> anyhow::Result<Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>> {
let endpoint_id = endpoint.id();
// Fast path: return existing if present
if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
return Ok(existing.clone());
}
// Atomic get-or-insert to avoid TOCTOU race
let inner_map = Arc::new(DashMap::new());
let (map, is_new) = match self.runtime_configs.entry(endpoint_id) {
Entry::Occupied(e) => (e.get().clone(), false),
Entry::Vacant(e) => {
e.insert(inner_map.clone());
(inner_map, true)
}
};
// Only spawn watcher if we were the one who inserted
if is_new {
self.spawn_runtime_config_watcher(endpoint, map.clone())
.await?;
}
Ok(map)
}
/// Get disaggregated endpoint for a specific worker.
/// Used by PrefillRouter for bootstrap info - works for ANY routing mode.
pub fn get_disaggregated_endpoint(
&self,
endpoint_id: &EndpointId,
worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
let inner_map = self.runtime_configs.get(endpoint_id)?;
let config_ref = inner_map.get(&worker_id)?;
config_ref.as_ref()?.disaggregated_endpoint.clone()
}
/// Spawn background task to watch runtime configs via discovery.
async fn spawn_runtime_config_watcher(
&self,
endpoint: &Endpoint,
inner_map: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
) -> anyhow::Result<()> {
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
// Set up discovery watch for EndpointModels
let discovery = component.drt().discovery();
let endpoint_id = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels {
namespace: endpoint_id.namespace.clone(),
component: endpoint_id.component.clone(),
endpoint: endpoint_id.name.clone(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
.await?;
// Extract runtime_config from ModelDeploymentCard
let mut runtime_configs_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
card.runtime_config
});
// Also watch instance IDs
let client = endpoint.client().await?;
let mut instance_ids_rx = client.instance_avail_watcher();
// Spawn background task to update inner_map
let cancel_token = cancellation_token.clone();
tokio::spawn(async move {
tracing::trace!("ModelManager runtime config watcher started");
loop {
// Wait for either instances or configs to change
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("ModelManager runtime config watcher shutting down");
break;
}
result = instance_ids_rx.changed() => {
if result.is_err() {
tracing::warn!("instance IDs watch sender shutdown in ModelManager");
break;
}
}
result = runtime_configs_rx.changed() => {
if result.is_err() {
tracing::warn!("runtime configs watch sender shutdown in ModelManager");
break;
}
}
}
// Get the latest values from both channels
let new_instance_ids = instance_ids_rx.borrow_and_update().clone();
let new_configs = runtime_configs_rx.borrow_and_update().clone();
// Update the DashMap
// First, remove workers that no longer exist
let current_workers: HashSet<WorkerId> =
inner_map.iter().map(|r| *r.key()).collect();
let new_workers: HashSet<WorkerId> = new_instance_ids.iter().copied().collect();
for removed_worker in current_workers.difference(&new_workers) {
inner_map.remove(removed_worker);
}
// Then, add/update workers
for worker_id in &new_instance_ids {
let config = new_configs.get(worker_id).cloned();
if config.is_some() {
let prev_config = inner_map.get(worker_id);
if prev_config.as_ref().map(|r| r.value()) != Some(&config) {
tracing::info!(
"ModelManager: Runtime config found for worker_id: {}",
worker_id
);
}
}
inner_map.insert(*worker_id, config);
}
tracing::trace!(
"ModelManager: Updated runtime_configs with {} workers",
inner_map.len()
);
}
tracing::trace!("ModelManager runtime config watcher shutting down");
});
Ok(())
}
/// Lists all models that have worker monitors (and thus busy thresholds) configured. /// Lists all models that have worker monitors (and thus busy thresholds) configured.
/// ///
/// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples. /// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
......
...@@ -449,6 +449,7 @@ impl ModelWatcher { ...@@ -449,6 +449,7 @@ impl ModelWatcher {
>( >(
card, card,
&client, &client,
self.manager.clone(),
self.router_config.router_mode, self.router_config.router_mode,
worker_monitor.clone(), worker_monitor.clone(),
kv_chooser.clone(), kv_chooser.clone(),
...@@ -482,6 +483,7 @@ impl ModelWatcher { ...@@ -482,6 +483,7 @@ impl ModelWatcher {
>( >(
card, card,
&client, &client,
self.manager.clone(),
self.router_config.router_mode, self.router_config.router_mode,
worker_monitor, worker_monitor,
kv_chooser, kv_chooser,
......
...@@ -172,6 +172,7 @@ where ...@@ -172,6 +172,7 @@ where
pub async fn build_routed_pipeline<Req, Resp>( pub async fn build_routed_pipeline<Req, Resp>(
card: &ModelDeploymentCard, card: &ModelDeploymentCard,
client: &Client, client: &Client,
model_manager: Arc<crate::discovery::ModelManager>,
router_mode: RouterMode, router_mode: RouterMode,
worker_monitor: Option<KvWorkerMonitor>, worker_monitor: Option<KvWorkerMonitor>,
chooser: Option<Arc<KvRouter>>, chooser: Option<Arc<KvRouter>>,
...@@ -196,6 +197,7 @@ where ...@@ -196,6 +197,7 @@ where
build_routed_pipeline_with_preprocessor( build_routed_pipeline_with_preprocessor(
card, card,
client, client,
model_manager,
router_mode, router_mode,
worker_monitor, worker_monitor,
chooser, chooser,
...@@ -212,6 +214,7 @@ where ...@@ -212,6 +214,7 @@ where
pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>( pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
card: &ModelDeploymentCard, card: &ModelDeploymentCard,
client: &Client, client: &Client,
model_manager: Arc<crate::discovery::ModelManager>,
router_mode: RouterMode, router_mode: RouterMode,
worker_monitor: Option<KvWorkerMonitor>, worker_monitor: Option<KvWorkerMonitor>,
chooser: Option<Arc<KvRouter>>, chooser: Option<Arc<KvRouter>>,
...@@ -277,8 +280,8 @@ where ...@@ -277,8 +280,8 @@ where
}; };
// Use the provided prefill chooser, or create a disabled one if not provided // Use the provided prefill chooser, or create a disabled one if not provided
let prefill_chooser = let prefill_chooser = prefill_chooser
prefill_chooser.unwrap_or_else(|| PrefillRouter::disabled(router_mode, enforce_disagg)); .unwrap_or_else(|| PrefillRouter::disabled(model_manager, router_mode, enforce_disagg));
let prefill_op = prefill_chooser.into_operator(); let prefill_op = prefill_chooser.into_operator();
// Link with prefill chooser including backward edge for response flow // Link with prefill chooser including backward edge for response flow
......
...@@ -6,6 +6,7 @@ use std::sync::Arc; ...@@ -6,6 +6,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use anyhow::Result; use anyhow::Result;
use dashmap::DashMap;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Client, Endpoint}, component::{Client, Endpoint},
...@@ -40,13 +41,11 @@ use worker_query::WorkerQueryClient; ...@@ -40,13 +41,11 @@ use worker_query::WorkerQueryClient;
use crate::{ use crate::{
kv_router::{ kv_router::{
approx::PruneConfig, approx::PruneConfig,
indexer::{ indexer::{KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent},
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
protocols::{ protocols::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult, LocalBlockHash, RouterRequest, RouterResponse, TokensWithHashes, WorkerId,
WorkerWithDpRank, WorkerSelectionResult, WorkerWithDpRank, compute_block_hash_for_seq,
compute_seq_hash_for_block,
}, },
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError, sequence::SequenceError,
...@@ -57,7 +56,6 @@ use crate::{ ...@@ -57,7 +56,6 @@ use crate::{
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
protocols::common::timing::RequestPhase, protocols::common::timing::RequestPhase,
tokens::SequenceHash,
}; };
// [gluo TODO] shouldn't need to be public // [gluo TODO] shouldn't need to be public
...@@ -148,7 +146,7 @@ pub struct KvRouterConfig { ...@@ -148,7 +146,7 @@ pub struct KvRouterConfig {
/// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0) /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
pub router_ttl_secs: f64, pub router_ttl_secs: f64,
/// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024) /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
pub router_max_tree_size: usize, pub router_max_tree_size: usize,
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8) /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
...@@ -166,7 +164,7 @@ impl Default for KvRouterConfig { ...@@ -166,7 +164,7 @@ impl Default for KvRouterConfig {
router_snapshot_threshold: Some(1000000), router_snapshot_threshold: Some(1000000),
router_reset_states: false, router_reset_states: false,
router_ttl_secs: 120.0, router_ttl_secs: 120.0,
router_max_tree_size: 1024, router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8, router_prune_target_ratio: 0.8,
} }
} }
...@@ -244,16 +242,15 @@ impl Indexer { ...@@ -244,16 +242,15 @@ impl Indexer {
} }
} }
async fn process_routing_decision( async fn process_routing_decision_for_request(
&self, &self,
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
match self { match self {
Indexer::KvIndexer(indexer) => { Indexer::KvIndexer(indexer) => {
indexer indexer
.process_routing_decision(worker, local_hashes, sequence_hashes) .process_routing_decision_for_request(tokens_with_hashes, worker)
.await .await
} }
Indexer::None => Ok(()), Indexer::None => Ok(()),
...@@ -284,6 +281,7 @@ impl KvRouter { ...@@ -284,6 +281,7 @@ impl KvRouter {
pub async fn new( pub async fn new(
endpoint: Endpoint, endpoint: Endpoint,
client: Client, client: Client,
workers_with_configs: Arc<DashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>>,
block_size: u32, block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
...@@ -296,6 +294,7 @@ impl KvRouter { ...@@ -296,6 +294,7 @@ impl KvRouter {
let instance_ids_rx = client.instance_avail_watcher(); let instance_ids_rx = client.instance_avail_watcher();
// Watch for runtime config updates via discovery interface // Watch for runtime config updates via discovery interface
// (still needed for WorkerQueryClient and background tasks)
let discovery = component.drt().discovery(); let discovery = component.drt().discovery();
let endpoint_id = endpoint.id(); let endpoint_id = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels { let discovery_key = DiscoveryQuery::EndpointModels {
...@@ -341,7 +340,7 @@ impl KvRouter { ...@@ -341,7 +340,7 @@ impl KvRouter {
component.clone(), component.clone(),
block_size, block_size,
instance_ids_rx, instance_ids_rx,
runtime_configs_rx.clone(), workers_with_configs,
selector, selector,
kv_router_config.router_replica_sync, kv_router_config.router_replica_sync,
consumer_id.clone(), consumer_id.clone(),
...@@ -476,41 +475,28 @@ impl KvRouter { ...@@ -476,41 +475,28 @@ impl KvRouter {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?; let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
// Determine who needs seq_hashes // Compute seq_hashes only if scheduler needs it for active blocks tracking
let needs_process_routing = !self.kv_router_config.use_kv_events; let maybe_seq_hashes = self
let scheduler_needs_it = self.kv_router_config.router_track_active_blocks; .kv_router_config
.router_track_active_blocks
// Optimize cloning: only clone if both need it, otherwise move .then(|| compute_seq_hash_for_block(&block_hashes));
let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
match (needs_process_routing, scheduler_needs_it) {
(true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
(true, false) => (Some(seq_hashes), None),
(false, true) => (None, Some(seq_hashes)),
(false, false) => (None, None),
};
let best_worker = self let best_worker = self
.scheduler .scheduler
.schedule( .schedule(
context_id.map(|s| s.to_string()), context_id.map(|s| s.to_string()),
isl_tokens, isl_tokens,
maybe_seq_hashes_2, maybe_seq_hashes,
overlap_scores.clone(), overlap_scores.clone(),
router_config_override, router_config_override,
update_states, update_states,
) )
.await?; .await?;
// Process routing decision when not using KV events (approximate mode with TTL/pruning) // Note: Routing decision recording (for approximate mode) is now handled
if needs_process_routing { // by KvPushRouter::generate after select_worker returns.
self.indexer
.process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
.await?;
}
let overlap_amount = overlap_scores let overlap_amount = overlap_scores
.scores .scores
...@@ -573,25 +559,16 @@ impl KvRouter { ...@@ -573,25 +559,16 @@ impl KvRouter {
Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0)) Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
} }
/// Get the disaggregated endpoint for a worker, if available.
/// Used to look up bootstrap host/port for prefill workers.
pub async fn get_disaggregated_endpoint(
&self,
worker_id: u64,
) -> Option<crate::local_model::runtime_config::DisaggregatedEndpoint> {
self.scheduler.get_disaggregated_endpoint(worker_id).await
}
/// Get potential prefill and decode loads for all workers /// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> { pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let overlap_scores = self.indexer.find_matches(block_hashes).await?; let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| { let maybe_seq_hashes = self
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None); .kv_router_config
compute_seq_hash_for_block(&block_hashes) .router_track_active_blocks
}); .then(|| compute_seq_hash_for_block(&block_hashes));
Ok(self Ok(self
.scheduler .scheduler
...@@ -838,14 +815,16 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -838,14 +815,16 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let is_query_only = request.get_annotation_value("query_instance_id").is_some(); let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Determine if this router should handle local state updates (add_request, free, etc.) // Determine if this router should handle local state updates (add_request, free, etc.)
// When routing hints are present, the external caller handles state tracking // Only skip local updates for GAIE Stage 2: when BOTH prefill and decode worker IDs
// via separate API calls, so we skip local updates here. // are externally specified (indicates external orchestrator handles tracking).
// For internal routing (e.g., bootstrap optimization with only prefill_worker_id set),
// we still handle updates locally.
let routing = request.routing.as_ref(); let routing = request.routing.as_ref();
let handle_local_updates = routing let handle_local_updates = routing
.map(|r| { .map(|r| {
// No routing hints = we handle updates locally // GAIE Stage 2 sets both worker IDs - external caller handles tracking
r.backend_instance_id.is_none() // All other cases (including backend_instance_id for routing) - we handle locally
&& (r.prefill_worker_id.is_none() || r.decode_worker_id.is_none()) r.prefill_worker_id.is_none() || r.decode_worker_id.is_none()
}) })
.unwrap_or(true); .unwrap_or(true);
...@@ -872,6 +851,29 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -872,6 +851,29 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
overlap_amount, overlap_amount,
} = selection; } = selection;
// In approximate mode (use_kv_events=false), record the routing decision
// so the indexer can track cache state based on routing decisions.
// This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config.use_kv_events {
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size);
if let Err(e) = self
.chooser
.indexer
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
{
tracing::warn!(
request_id = %context_id,
worker_id = instance_id,
dp_rank = dp_rank,
error = %e,
"Failed to record routing decision in approximate mode"
);
}
}
// Record metrics in tracker: KV hit rate and worker ID based on phase // Record metrics in tracker: KV hit rate and worker ID based on phase
if let Some(ref tracker) = request.tracker { if let Some(ref tracker) = request.tracker {
let isl_blocks = request.token_ids.len().div_ceil(block_size); let isl_blocks = request.token_ids.len().div_ceil(block_size);
...@@ -936,11 +938,18 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -936,11 +938,18 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}; };
if !prefill_marked { if !prefill_marked {
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref()
.map(|d| !d.token_ids.is_empty())
.unwrap_or(false);
if has_tokens {
if let Err(e) = chooser.mark_prefill_completed(&context_id).await { if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}"); tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
} }
prefill_marked = true; prefill_marked = true;
} }
}
yield item; yield item;
} }
......
...@@ -223,7 +223,7 @@ impl<K: Clone + Hash + Eq + Ord> PruneManager<K> { ...@@ -223,7 +223,7 @@ impl<K: Clone + Hash + Eq + Ord> PruneManager<K> {
mod tests { mod tests {
use super::*; use super::*;
use crate::kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics}; use crate::kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use crate::kv_router::protocols::{WorkerId, WorkerWithDpRank}; use crate::kv_router::protocols::{TokensWithHashes, WorkerId, WorkerWithDpRank};
use std::sync::Arc; use std::sync::Arc;
use tokio::time::{self, Duration, Instant}; use tokio::time::{self, Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -355,9 +355,10 @@ mod tests { ...@@ -355,9 +355,10 @@ mod tests {
assert!(pre_scores.scores.is_empty()); assert!(pre_scores.scores.is_empty());
// 2. Inform indexer about routing decision // 2. Inform indexer about routing decision
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request( .process_routing_decision_for_request(
&tokens, &mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_id), WorkerWithDpRank::from_worker_id(worker_id),
) )
.await .await
...@@ -401,9 +402,10 @@ mod tests { ...@@ -401,9 +402,10 @@ mod tests {
let tokens: Vec<u32> = vec![10, 11, 12, 13]; let tokens: Vec<u32> = vec![10, 11, 12, 13];
let worker_id: WorkerId = 7; let worker_id: WorkerId = 7;
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request( .process_routing_decision_for_request(
&tokens, &mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_id), WorkerWithDpRank::from_worker_id(worker_id),
) )
.await .await
...@@ -454,16 +456,18 @@ mod tests { ...@@ -454,16 +456,18 @@ mod tests {
let worker_1: WorkerId = 31; let worker_1: WorkerId = 31;
// Register on both workers // Register on both workers
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request( .process_routing_decision_for_request(
&tokens, &mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_0), WorkerWithDpRank::from_worker_id(worker_0),
) )
.await .await
.unwrap(); .unwrap();
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request( .process_routing_decision_for_request(
&tokens, &mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_1), WorkerWithDpRank::from_worker_id(worker_1),
) )
.await .await
...@@ -524,9 +528,10 @@ mod tests { ...@@ -524,9 +528,10 @@ mod tests {
let worker_a: WorkerId = 11; let worker_a: WorkerId = 11;
// Register Sequence A on worker A // Register Sequence A on worker A
let mut tokens_with_hashes = TokensWithHashes::new(seq_a.clone(), KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request( .process_routing_decision_for_request(
&seq_a, &mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_a), WorkerWithDpRank::from_worker_id(worker_a),
) )
.await .await
...@@ -582,16 +587,18 @@ mod tests { ...@@ -582,16 +587,18 @@ mod tests {
let worker_1: WorkerId = 22; let worker_1: WorkerId = 22;
// Register the same sequence on two different workers // Register the same sequence on two different workers
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request( .process_routing_decision_for_request(
&tokens, &mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_0), WorkerWithDpRank::from_worker_id(worker_0),
) )
.await .await
.unwrap(); .unwrap();
let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request( .process_routing_decision_for_request(
&tokens, &mut tokens_with_hashes,
WorkerWithDpRank::from_worker_id(worker_1), WorkerWithDpRank::from_worker_id(worker_1),
) )
.await .await
...@@ -759,8 +766,9 @@ mod tests { ...@@ -759,8 +766,9 @@ mod tests {
// Insert 5 sequences (5 blocks total, at max_tree_size but not exceeding) // Insert 5 sequences (5 blocks total, at max_tree_size but not exceeding)
for i in 0..5 { for i in 0..5 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3]; let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request(&tokens, worker) .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await .await
.unwrap(); .unwrap();
time::sleep(Duration::from_millis(1)).await; // Ensure different timestamps time::sleep(Duration::from_millis(1)).await; // Ensure different timestamps
...@@ -780,8 +788,9 @@ mod tests { ...@@ -780,8 +788,9 @@ mod tests {
// Insert 6th block - this exceeds max_tree_size and should trigger reactive pruning // Insert 6th block - this exceeds max_tree_size and should trigger reactive pruning
let tokens: Vec<u32> = vec![50, 51, 52, 53]; let tokens: Vec<u32> = vec![50, 51, 52, 53];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, KV_BLOCK_SIZE);
indexer indexer
.process_routing_decision_for_request(&tokens, worker) .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await .await
.unwrap(); .unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment