Unverified Commit c2a29f80 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: enable local indexers for sglang and trtllm (#4932)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent feb49914
......@@ -116,6 +116,13 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": os.environ.get("DYN_REQUEST_PLANE", "nats"),
"help": "Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
},
"enable-local-indexer": {
"flags": ["--enable-local-indexer"],
"type": str,
"choices": ["true", "false"],
"default": os.environ.get("DYN_LOCAL_INDEXER", "false"),
"help": "Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).",
},
}
......@@ -148,6 +155,8 @@ class DynamoArgs:
embedding_worker: bool = False
# config dump options
dump_config_to: Optional[str] = None
# local indexer option
enable_local_indexer: bool = False
class DisaggregationMode(Enum):
......@@ -477,6 +486,7 @@ async def parse_args(args: list[str]) -> Config:
multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker,
dump_config_to=parsed_args.dump_config_to,
enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true",
)
logging.debug(f"Dynamo args: {dynamo_args}")
......
......@@ -74,6 +74,7 @@ class DynamoSglangPublisher:
"""
self.engine = engine
self.server_args = config.server_args
self.dynamo_args = config.dynamo_args
self.generate_endpoint = generate_endpoint
self.component = component
self.metrics_publisher = WorkerMetricsPublisher()
......@@ -151,6 +152,7 @@ class DynamoSglangPublisher:
worker_id=self.generate_endpoint.connection_id(),
kv_block_size=self.server_args.page_size,
zmq_endpoint=zmq_ep,
enable_local_indexer=self.dynamo_args.enable_local_indexer,
)
logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
self.kv_publisher = ZmqKvEventPublisher(
......
......@@ -82,6 +82,7 @@ async def _get_runtime_config(
# set reasoning parser and tool call parser
runtime_config.reasoning_parser = dynamo_args.reasoning_parser
runtime_config.tool_call_parser = dynamo_args.tool_call_parser
runtime_config.enable_local_indexer = dynamo_args.enable_local_indexer
# In SGLang, these are server_args, not scheduler_info (unlike vLLM)
# Note: If --max-running-requests is not specified, SGLang uses an internal default
......
......@@ -351,6 +351,7 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.max_num_batched_tokens = config.max_num_tokens
runtime_config.reasoning_parser = config.reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.enable_local_indexer = config.enable_local_indexer
logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}")
logging.info(
......@@ -458,6 +459,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.kv_block_size,
metrics_labels,
zmq_endpoint=trtllm_zmq_bind_endpoint,
enable_local_indexer=config.enable_local_indexer,
) as publisher:
handler_config.publisher = publisher
handler = RequestHandlerFactory().get_request_handler(handler_config)
......
......@@ -276,6 +276,7 @@ class Publisher:
kv_block_size,
metrics_labels,
zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False,
):
self.component = component
self.engine = engine
......@@ -284,6 +285,7 @@ class Publisher:
self.kv_block_size = kv_block_size
self.max_window_size = None
self.metrics_labels = metrics_labels
self.enable_local_indexer = enable_local_indexer
# The first few kv events from the model engine are always "created" type events.
# Use these events to capture the max_window_size of the model.
......@@ -348,7 +350,11 @@ class Publisher:
else:
# No consolidator: use NATS publisher (router subscribes directly)
self.kv_event_publisher = KvEventPublisher(
self.kv_listener, self.worker_id, self.kv_block_size, dp_rank=0
self.kv_listener,
self.worker_id,
self.kv_block_size,
dp_rank=0,
enable_local_indexer=self.enable_local_indexer,
)
# Always initialize the thread - it routes to either ZMQ or NATS publisher
......@@ -685,6 +691,7 @@ async def get_publisher(
kv_block_size,
metrics_labels,
zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False,
):
publisher = Publisher(
component,
......@@ -694,6 +701,7 @@ async def get_publisher(
kv_block_size,
metrics_labels,
zmq_endpoint=zmq_endpoint,
enable_local_indexer=enable_local_indexer,
)
try:
publisher.initialize()
......
......@@ -61,6 +61,7 @@ class Config:
self.dyn_endpoint_types: str = "chat,completions"
self.store_kv: str = ""
self.request_plane: str = ""
self.enable_local_indexer: bool = False
def __str__(self) -> str:
return (
......@@ -93,7 +94,8 @@ class Config:
f"dump_config_to={self.dump_config_to}, "
f"custom_jinja_template={self.custom_jinja_template}, "
f"store_kv={self.store_kv}, "
f"request_plane={self.request_plane}"
f"request_plane={self.request_plane}, "
f"enable_local_indexer={self.enable_local_indexer}"
)
......@@ -303,6 +305,13 @@ def cmd_line_args():
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
parser.add_argument(
"--enable-local-indexer",
type=str,
choices=["true", "false"],
default=os.environ.get("DYN_LOCAL_INDEXER", "false"),
help="Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).",
)
args = parser.parse_args()
......@@ -365,6 +374,7 @@ def cmd_line_args():
config.dyn_endpoint_types = args.dyn_endpoint_types
config.store_kv = args.store_kv
config.request_plane = args.request_plane
config.enable_local_indexer = str(args.enable_local_indexer).lower() == "true"
# Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template:
......
......@@ -245,12 +245,13 @@ pub(crate) struct KvEventPublisher {
#[pymethods]
impl KvEventPublisher {
#[new]
#[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))]
#[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0, enable_local_indexer=false))]
fn new(
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
enable_local_indexer: bool,
) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
......@@ -260,10 +261,11 @@ impl KvEventPublisher {
// The actual worker_id is inferred from component's connection_id in the Rust implementation.
let _ = worker_id;
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
component.inner,
kv_block_size as u32,
None,
enable_local_indexer,
)
.map_err(to_pyerr)?;
......
......@@ -794,7 +794,7 @@ class KvEventPublisher:
...
def __init__(
self, component: Component, worker_id: int, kv_block_size: int, dp_rank: int = 0
self, component: Component, worker_id: int, kv_block_size: int, dp_rank: int = 0, enable_local_indexer: bool = False
) -> None:
"""
Create a `KvEventPublisher` object
......@@ -804,6 +804,7 @@ class KvEventPublisher:
worker_id: The worker ID
kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0)
enable_local_indexer: Enable worker-local KV indexer (defaults to False)
"""
def publish_stored(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment