Commit c2a29f80 (unverified), authored by Yan Ru Pei, committed by GitHub
Browse files

chore: enable local indexers for sglang and trtllm (#4932)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent feb49914
...@@ -116,6 +116,13 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -116,6 +116,13 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": os.environ.get("DYN_REQUEST_PLANE", "nats"), "default": os.environ.get("DYN_REQUEST_PLANE", "nats"),
"help": "Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", "help": "Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
}, },
"enable-local-indexer": {
"flags": ["--enable-local-indexer"],
"type": str,
"choices": ["true", "false"],
"default": os.environ.get("DYN_LOCAL_INDEXER", "false"),
"help": "Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).",
},
} }
...@@ -148,6 +155,8 @@ class DynamoArgs: ...@@ -148,6 +155,8 @@ class DynamoArgs:
embedding_worker: bool = False embedding_worker: bool = False
# config dump options # config dump options
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
# local indexer option
enable_local_indexer: bool = False
class DisaggregationMode(Enum): class DisaggregationMode(Enum):
...@@ -477,6 +486,7 @@ async def parse_args(args: list[str]) -> Config: ...@@ -477,6 +486,7 @@ async def parse_args(args: list[str]) -> Config:
multimodal_worker=parsed_args.multimodal_worker, multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker, embedding_worker=parsed_args.embedding_worker,
dump_config_to=parsed_args.dump_config_to, dump_config_to=parsed_args.dump_config_to,
enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true",
) )
logging.debug(f"Dynamo args: {dynamo_args}") logging.debug(f"Dynamo args: {dynamo_args}")
......
...@@ -74,6 +74,7 @@ class DynamoSglangPublisher: ...@@ -74,6 +74,7 @@ class DynamoSglangPublisher:
""" """
self.engine = engine self.engine = engine
self.server_args = config.server_args self.server_args = config.server_args
self.dynamo_args = config.dynamo_args
self.generate_endpoint = generate_endpoint self.generate_endpoint = generate_endpoint
self.component = component self.component = component
self.metrics_publisher = WorkerMetricsPublisher() self.metrics_publisher = WorkerMetricsPublisher()
...@@ -151,6 +152,7 @@ class DynamoSglangPublisher: ...@@ -151,6 +152,7 @@ class DynamoSglangPublisher:
worker_id=self.generate_endpoint.connection_id(), worker_id=self.generate_endpoint.connection_id(),
kv_block_size=self.server_args.page_size, kv_block_size=self.server_args.page_size,
zmq_endpoint=zmq_ep, zmq_endpoint=zmq_ep,
enable_local_indexer=self.dynamo_args.enable_local_indexer,
) )
logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}") logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
self.kv_publisher = ZmqKvEventPublisher( self.kv_publisher = ZmqKvEventPublisher(
......
...@@ -82,6 +82,7 @@ async def _get_runtime_config( ...@@ -82,6 +82,7 @@ async def _get_runtime_config(
# set reasoning parser and tool call parser # set reasoning parser and tool call parser
runtime_config.reasoning_parser = dynamo_args.reasoning_parser runtime_config.reasoning_parser = dynamo_args.reasoning_parser
runtime_config.tool_call_parser = dynamo_args.tool_call_parser runtime_config.tool_call_parser = dynamo_args.tool_call_parser
runtime_config.enable_local_indexer = dynamo_args.enable_local_indexer
# In SGLang, these are server_args, not scheduler_info (unlike vLLM) # In SGLang, these are server_args, not scheduler_info (unlike vLLM)
# Note: If --max-running-requests is not specified, SGLang uses an internal default # Note: If --max-running-requests is not specified, SGLang uses an internal default
......
...@@ -351,6 +351,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -351,6 +351,7 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.max_num_batched_tokens = config.max_num_tokens runtime_config.max_num_batched_tokens = config.max_num_tokens
runtime_config.reasoning_parser = config.reasoning_parser runtime_config.reasoning_parser = config.reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.enable_local_indexer = config.enable_local_indexer
logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}") logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}")
logging.info( logging.info(
...@@ -458,6 +459,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -458,6 +459,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.kv_block_size, config.kv_block_size,
metrics_labels, metrics_labels,
zmq_endpoint=trtllm_zmq_bind_endpoint, zmq_endpoint=trtllm_zmq_bind_endpoint,
enable_local_indexer=config.enable_local_indexer,
) as publisher: ) as publisher:
handler_config.publisher = publisher handler_config.publisher = publisher
handler = RequestHandlerFactory().get_request_handler(handler_config) handler = RequestHandlerFactory().get_request_handler(handler_config)
......
...@@ -276,6 +276,7 @@ class Publisher: ...@@ -276,6 +276,7 @@ class Publisher:
kv_block_size, kv_block_size,
metrics_labels, metrics_labels,
zmq_endpoint: Optional[str] = None, zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False,
): ):
self.component = component self.component = component
self.engine = engine self.engine = engine
...@@ -284,6 +285,7 @@ class Publisher: ...@@ -284,6 +285,7 @@ class Publisher:
self.kv_block_size = kv_block_size self.kv_block_size = kv_block_size
self.max_window_size = None self.max_window_size = None
self.metrics_labels = metrics_labels self.metrics_labels = metrics_labels
self.enable_local_indexer = enable_local_indexer
# The first few kv events from the model engine are always "created" type events. # The first few kv events from the model engine are always "created" type events.
# Use these events to capture the max_window_size of the model. # Use these events to capture the max_window_size of the model.
...@@ -348,7 +350,11 @@ class Publisher: ...@@ -348,7 +350,11 @@ class Publisher:
else: else:
# No consolidator: use NATS publisher (router subscribes directly) # No consolidator: use NATS publisher (router subscribes directly)
self.kv_event_publisher = KvEventPublisher( self.kv_event_publisher = KvEventPublisher(
self.kv_listener, self.worker_id, self.kv_block_size, dp_rank=0 self.kv_listener,
self.worker_id,
self.kv_block_size,
dp_rank=0,
enable_local_indexer=self.enable_local_indexer,
) )
# Always initialize the thread - it routes to either ZMQ or NATS publisher # Always initialize the thread - it routes to either ZMQ or NATS publisher
...@@ -685,6 +691,7 @@ async def get_publisher( ...@@ -685,6 +691,7 @@ async def get_publisher(
kv_block_size, kv_block_size,
metrics_labels, metrics_labels,
zmq_endpoint: Optional[str] = None, zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False,
): ):
publisher = Publisher( publisher = Publisher(
component, component,
...@@ -694,6 +701,7 @@ async def get_publisher( ...@@ -694,6 +701,7 @@ async def get_publisher(
kv_block_size, kv_block_size,
metrics_labels, metrics_labels,
zmq_endpoint=zmq_endpoint, zmq_endpoint=zmq_endpoint,
enable_local_indexer=enable_local_indexer,
) )
try: try:
publisher.initialize() publisher.initialize()
......
...@@ -61,6 +61,7 @@ class Config: ...@@ -61,6 +61,7 @@ class Config:
self.dyn_endpoint_types: str = "chat,completions" self.dyn_endpoint_types: str = "chat,completions"
self.store_kv: str = "" self.store_kv: str = ""
self.request_plane: str = "" self.request_plane: str = ""
self.enable_local_indexer: bool = False
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
...@@ -93,7 +94,8 @@ class Config: ...@@ -93,7 +94,8 @@ class Config:
f"dump_config_to={self.dump_config_to}, " f"dump_config_to={self.dump_config_to}, "
f"custom_jinja_template={self.custom_jinja_template}, " f"custom_jinja_template={self.custom_jinja_template}, "
f"store_kv={self.store_kv}, " f"store_kv={self.store_kv}, "
f"request_plane={self.request_plane}" f"request_plane={self.request_plane}, "
f"enable_local_indexer={self.enable_local_indexer}"
) )
...@@ -303,6 +305,13 @@ def cmd_line_args(): ...@@ -303,6 +305,13 @@ def cmd_line_args():
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
) )
parser.add_argument(
"--enable-local-indexer",
type=str,
choices=["true", "false"],
default=os.environ.get("DYN_LOCAL_INDEXER", "false"),
help="Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).",
)
args = parser.parse_args() args = parser.parse_args()
...@@ -365,6 +374,7 @@ def cmd_line_args(): ...@@ -365,6 +374,7 @@ def cmd_line_args():
config.dyn_endpoint_types = args.dyn_endpoint_types config.dyn_endpoint_types = args.dyn_endpoint_types
config.store_kv = args.store_kv config.store_kv = args.store_kv
config.request_plane = args.request_plane config.request_plane = args.request_plane
config.enable_local_indexer = str(args.enable_local_indexer).lower() == "true"
# Handle custom jinja template path expansion (environment variables and home directory) # Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template: if args.custom_jinja_template:
......
...@@ -245,12 +245,13 @@ pub(crate) struct KvEventPublisher { ...@@ -245,12 +245,13 @@ pub(crate) struct KvEventPublisher {
#[pymethods] #[pymethods]
impl KvEventPublisher { impl KvEventPublisher {
#[new] #[new]
#[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))] #[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0, enable_local_indexer=false))]
fn new( fn new(
component: Component, component: Component,
worker_id: WorkerId, worker_id: WorkerId,
kv_block_size: usize, kv_block_size: usize,
dp_rank: DpRank, dp_rank: DpRank,
enable_local_indexer: bool,
) -> PyResult<Self> { ) -> PyResult<Self> {
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
...@@ -260,10 +261,11 @@ impl KvEventPublisher { ...@@ -260,10 +261,11 @@ impl KvEventPublisher {
// The actual worker_id is inferred from component's connection_id in the Rust implementation. // The actual worker_id is inferred from component's connection_id in the Rust implementation.
let _ = worker_id; let _ = worker_id;
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
component.inner, component.inner,
kv_block_size as u32, kv_block_size as u32,
None, None,
enable_local_indexer,
) )
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -794,7 +794,7 @@ class KvEventPublisher: ...@@ -794,7 +794,7 @@ class KvEventPublisher:
... ...
def __init__( def __init__(
self, component: Component, worker_id: int, kv_block_size: int, dp_rank: int = 0 self, component: Component, worker_id: int, kv_block_size: int, dp_rank: int = 0, enable_local_indexer: bool = False
) -> None: ) -> None:
""" """
Create a `KvEventPublisher` object Create a `KvEventPublisher` object
...@@ -804,6 +804,7 @@ class KvEventPublisher: ...@@ -804,6 +804,7 @@ class KvEventPublisher:
worker_id: The worker ID worker_id: The worker ID
kv_block_size: The KV block size (must be > 0) kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0) dp_rank: The data parallel rank (defaults to 0)
enable_local_indexer: Enable worker-local KV indexer (defaults to False)
""" """
def publish_stored( def publish_stored(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment