"launch/vscode:/vscode.git/clone" did not exist on "da83f820ebe7e2f353c559d18fba2d3ec4ce01a3"
Unverified Commit 80d8aa19 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat(sglang): unify entry point for SGLang backend architecture (#2493)

parent 28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Base handlers
from .base_handlers import BaseWorkerHandler
# Protocol types
from .protocol import (
DisaggPreprocessedRequest,
PreprocessedRequest,
SamplingOptions,
StopConditions,
TokenIdType,
)
# Utilities
from .sgl_utils import (
graceful_shutdown,
parse_sglang_args_inc,
reserve_free_port,
setup_native_endpoints,
)
# Public API of this package: every name listed here is imported above from a
# sibling module (.protocol, .sgl_utils or .base_handlers) and re-exported.
__all__ = [
# Protocol types
"DisaggPreprocessedRequest",
"PreprocessedRequest",
"SamplingOptions",
"StopConditions",
"TokenIdType",
# Utilities
"parse_sglang_args_inc",
"reserve_free_port",
"graceful_shutdown",
"setup_native_endpoints",
# Base handlers
"BaseWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Optional
import sglang as sgl
from sglang.srt.server_args import ServerArgs
class BaseWorkerHandler(ABC):
    """
    Abstract foundation for sglang request handlers.

    Subclasses provide ``generate``; the native sglang endpoints (cache flush and
    expert distribution recording) are implemented here by delegating to the
    engine's tokenizer manager. Each native endpoint is an async generator that
    yields ``True`` exactly once after the underlying call has completed.
    """

    @abstractmethod
    def __init__(
        self,
        engine: sgl.Engine,
        server_args: ServerArgs,
        component,
        decode_client: Optional[Any] = None,
    ):
        # decode_client is accepted but not stored by this base class.
        self.engine, self.server_args, self.component = (
            engine,
            server_args,
            component,
        )

    @abstractmethod
    async def generate(self, request):
        """Generate tokens from the engine"""
        ...

    async def flush_cache(self, request: dict):
        """Flush the KV cache; the request payload is ignored."""
        tokenizer_manager = self.engine.tokenizer_manager
        await tokenizer_manager.flush_cache()
        yield True

    async def start_expert_distribution_record(self, request: dict):
        """Begin recording the expert distribution; the request payload is ignored."""
        tokenizer_manager = self.engine.tokenizer_manager
        await tokenizer_manager.start_expert_distribution_record()
        yield True

    async def stop_expert_distribution_record(self, request: dict):
        """End recording the expert distribution; the request payload is ignored."""
        tokenizer_manager = self.engine.tokenizer_manager
        await tokenizer_manager.stop_expert_distribution_record()
        yield True

    async def dump_expert_distribution_record(self, request: dict):
        """
        Dump the expert distribution record to the directory named by the
        `SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR` environment variable.
        The request payload is ignored.
        """
        tokenizer_manager = self.engine.tokenizer_manager
        await tokenizer_manager.dump_expert_distribution_record()
        yield True
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import contextlib
import logging
import socket
from argparse import Namespace
from sglang.srt.server_args import ServerArgs
class SkipTokenizerInitError(RuntimeError):
    """Signals that the mandatory --skip-tokenizer-init flag was not supplied."""

    def __init__(self):
        message = "--skip-tokenizer-init flag is required"
        super().__init__(message)
def parse_sglang_args_inc(args: list[str]) -> ServerArgs:
    """
    Parse sglang CLI arguments into a ``ServerArgs`` instance.

    Args:
        args: Raw command-line tokens (typically ``sys.argv[1:]``).

    Returns:
        The ``ServerArgs`` built from the parsed namespace. When the caller did not
        pass ``--disaggregation-bootstrap-port`` a free port is picked and used.

    Raises:
        SkipTokenizerInitError: if none of ``--skip-tokenizer-init``, ``--version``,
            ``--help`` or ``-h`` is present in ``args``.
    """
    # Currently we only support Dynamo doing the tokenization, so we must give
    # sglang the skip-tokenizer-init flag. We don't default it because this is temporary.
    # Allow the --version and --help flags through.
    temp_need_tok = ("--skip-tokenizer-init", "--version", "--help", "-h")
    if not any(flag in args for flag in temp_need_tok):
        raise SkipTokenizerInitError()

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    parsed_args = parser.parse_args(args)

    # Matches both "--disaggregation-bootstrap-port 1234" and the "=1234" form.
    port_supplied = any(
        arg.startswith("--disaggregation-bootstrap-port") for arg in args
    )
    if not port_supplied:
        # Reserve a port only when it is actually going to be used, instead of
        # binding a socket on every call (including --help and explicit-port runs).
        parsed_args.disaggregation_bootstrap_port = (
            _reserve_disaggregation_bootstrap_port()
        )

    return ServerArgs.from_cli_args(parsed_args)
@contextlib.contextmanager
def reserve_free_port(host: str = "localhost"):
    """
    Hold an OS-assigned TCP port on ``host`` for the lifetime of the context.

    Binds a socket to port 0, yields the port number that was assigned, and
    closes the socket when the context exits (normally or via an exception).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        listener.bind((host, 0))
        yield listener.getsockname()[1]
def _reserve_disaggregation_bootstrap_port():
    """
    Pick a free port to use as a worker's disaggregation_bootstrap_port.

    Each worker needs its own port, so the number comes from ``reserve_free_port``.
    The reservation ends when this function returns; only the port number is
    handed back to the caller.
    """
    with reserve_free_port() as free_port:
        chosen = free_port
    return chosen
async def graceful_shutdown(runtime):
    """Call ``runtime.shutdown()``, logging immediately before and after the call."""
    log = logging.info
    log("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    log("DistributedRuntime shutdown complete")
def setup_native_endpoints(server_args, component, handler):
    """
    Build the serve tasks for the sgl native endpoints.

    ``flush_cache`` is always served. The three expert distribution endpoints are
    added only when ``server_args.expert_distribution_recorder_mode`` is not None.
    Every endpoint is bound to the handler method of the same name.

    Returns:
        A list holding the result of each ``serve_endpoint`` call, in the order
        flush, start, stop, dump.
    """
    tasks = [component.endpoint("flush_cache").serve_endpoint(handler.flush_cache)]

    if server_args.expert_distribution_recorder_mode is None:
        return tasks

    expert_endpoint_names = (
        "start_expert_distribution_record",
        "stop_expert_distribution_record",
        "dump_expert_distribution_record",
    )
    # Create all three endpoints first, then start serving them.
    expert_endpoints = [component.endpoint(name) for name in expert_endpoint_names]
    tasks.extend(
        endpoint.serve_endpoint(getattr(handler, name))
        for name, endpoint in zip(expert_endpoint_names, expert_endpoints)
    )
    return tasks
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# This module is deprecated. Use `python3 -m dynamo.sglang` instead.
...@@ -2,7 +2,16 @@ ...@@ -2,7 +2,16 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dynamo.sglang.decode_worker.main import main import logging
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.main import main
if __name__ == "__main__":
    # Deprecated entry point: configure logging, warn, then run the imported main().
    configure_dynamo_logging()
    logging.warning(
        # The trailing space keeps the two adjacent literals from running together
        # as "v0.5.0.Use" once Python concatenates them.
        "DEPRECATION WARNING: `python3 -m dynamo.sglang.decode_worker` is deprecated and will be removed in dynamo v0.5.0. "
        "Use `python3 -m dynamo.sglang` instead.",
    )
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import logging
import signal
import sys
import msgspec
import sglang as sgl
import uvloop
from sglang.srt.server_args import ServerArgs
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.common import (
BaseWorkerHandler,
graceful_shutdown,
parse_sglang_args_inc,
setup_native_endpoints,
)
configure_dynamo_logging()
class DecodeRequestHandler(BaseWorkerHandler):
    """Request handler that runs generation from a JSON-encoded request carrying bootstrap info."""

    def __init__(self, engine: sgl.Engine, server_args: ServerArgs, component):
        super().__init__(engine, server_args, component)
        logging.info("Decode request handler initialized")

    async def generate(self, request: str):
        """
        Decode the JSON ``request`` and stream the engine's results.

        ``batch_token_ids`` is used as the input when it is not None; otherwise
        ``token_ids`` is used. The sampling parameters and the bootstrap
        host/port/room are forwarded to the engine unchanged.
        """
        req = msgspec.json.decode(request, type=dict)

        inner = req["request"]
        input_ids = inner["batch_token_ids"]
        if input_ids is None:
            input_ids = inner["token_ids"]

        results = await self.engine.async_generate(
            input_ids=input_ids,
            sampling_params=req["sampling_params"],
            stream=True,
            bootstrap_host=req["bootstrap_host"],
            bootstrap_port=req["bootstrap_port"],
            bootstrap_room=req["bootstrap_room"],
        )
        async for chunk in results:
            yield chunk
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """
    Entry coroutine for the decode worker.

    Installs SIGTERM/SIGINT handlers that schedule ``graceful_shutdown``, parses
    the sglang arguments from ``sys.argv[1:]`` and then awaits ``init``.
    """
    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task can be garbage-collected before it finishes. Hold a strong reference
    # until the shutdown task completes.
    shutdown_tasks = set()

    def signal_handler():
        # Schedule the shutdown coroutine instead of calling it directly
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)
    logging.info("Signal handlers set up for graceful shutdown")

    server_args = parse_sglang_args_inc(sys.argv[1:])
    await init(runtime, server_args)
async def init(runtime: DistributedRuntime, server_args: ServerArgs):
    """
    Initialize decode worker.

    Creates the engine and the ``dynamo``/``decode`` component service, then
    serves the ``generate`` endpoint together with the native sgl endpoints
    until all of them finish.
    """
    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace("dynamo").component("decode")
    await component.create_service()

    handler = DecodeRequestHandler(engine, server_args, component)

    serving = [
        component.endpoint("generate").serve_endpoint(handler.generate),
        *setup_native_endpoints(server_args, component, handler),
    ]
    await asyncio.gather(*serving)
def main():
    """Install uvloop, then run the worker coroutine to completion."""
    uvloop.install()
    entry = worker()
    asyncio.run(entry)


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
import signal
import sys
import sglang as sgl
import uvloop
from sglang.srt.utils import get_ip
from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.publisher import setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_runtime_config
from dynamo.sglang.request_handlers import DecodeWorkerHandler, PrefillWorkerHandler
configure_dynamo_logging()
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """
    Entry coroutine for the unified sglang worker.

    Installs SIGTERM/SIGINT handlers that schedule ``graceful_shutdown``, parses
    the configuration from ``sys.argv[1:]`` and dispatches to ``init_prefill``
    when the serving mode is PREFILL, or to ``init`` for every other mode.
    """
    loop = asyncio.get_running_loop()

    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task can be garbage-collected before it finishes. Hold a strong reference
    # until the shutdown task completes.
    shutdown_tasks = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)
    logging.info("Signal handlers will trigger a graceful shutdown of the runtime")

    config = parse_args(sys.argv[1:])
    if config.serving_mode != DisaggregationMode.PREFILL:
        await init(runtime, config)
    else:
        await init_prefill(runtime, config)
async def init(runtime: DistributedRuntime, config: Config):
    """
    Start an sglang engine and serve its generate endpoint (non-prefill modes).

    Steps, in order: create the engine and the component service; in DECODE mode
    obtain a client for the ``prefill`` component's ``generate`` endpoint; start
    metrics publishing; create a ZMQ KV event publisher when
    ``server_args.kv_events_config`` is set; register the model; serve the
    endpoint. On the way out the metrics task is cancelled and awaited and the
    handler is cleaned up; a serving failure is logged and re-raised.
    """
    server_args = config.server_args
    dynamo_args = config.dynamo_args

    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    await component.create_service()
    generate_endpoint = component.endpoint(dynamo_args.endpoint)

    # TODO: think about implementing DisaggregationStrategy for P->D
    # TODO: implement a `next` field in the config to dynamically set the next client
    prefill_client = None
    if config.serving_mode == DisaggregationMode.DECODE:
        logging.info("Initializing prefill client")
        prefill_endpoint = (
            runtime.namespace(dynamo_args.namespace)
            .component("prefill")
            .endpoint("generate")
        )
        prefill_client = await prefill_endpoint.client()

    publisher, metrics_task = await setup_sgl_metrics(engine, component)

    kv_publisher = None
    if server_args.kv_events_config:
        endpoint_template = json.loads(server_args.kv_events_config).get("endpoint")
        # A "*" in the configured endpoint is replaced by this host's IP.
        zmq_ep = (
            endpoint_template.replace("*", get_ip()) if endpoint_template else None
        )
        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=generate_endpoint.lease_id(),
            kv_block_size=server_args.page_size,
            zmq_endpoint=zmq_ep,
        )
        logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
        kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)

    handler = DecodeWorkerHandler(
        component, engine, config, publisher, kv_publisher, prefill_client
    )

    await register_llm_with_runtime_config(
        engine, generate_endpoint, server_args, dynamo_args.migration_limit
    )

    # TODO: add in native endpoints
    serving = generate_endpoint.serve_endpoint(
        handler.generate, graceful_shutdown=False
    )
    try:
        await asyncio.gather(serving)
    except Exception as e:
        logging.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        metrics_task.cancel()
        try:
            await metrics_task
        except asyncio.CancelledError:
            logging.info("Metrics task succesfully cancelled")
        handler.cleanup()
async def init_prefill(runtime: DistributedRuntime, config: Config):
    """
    Start an sglang engine and serve its generate endpoint in prefill mode.

    Creates the engine, the component service and a ``PrefillWorkerHandler``,
    then serves the endpoint with ``graceful_shutdown=True``. A serving failure
    is logged and re-raised; the handler is cleaned up in every case.
    """
    server_args = config.server_args
    dynamo_args = config.dynamo_args

    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    await component.create_service()
    generate_endpoint = component.endpoint(dynamo_args.endpoint)

    handler = PrefillWorkerHandler(component, engine, config)

    serving = generate_endpoint.serve_endpoint(
        handler.generate, graceful_shutdown=True
    )
    try:
        await asyncio.gather(serving)
    except Exception as e:
        logging.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def graceful_shutdown(runtime):
    """Call ``runtime.shutdown()``, logging immediately before and after the call."""
    log = logging.info
    log("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    log("DistributedRuntime shutdown complete")
def main():
    """Run the worker coroutine to completion via ``uvloop.run``."""
    entry = worker()
    uvloop.run(entry)


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
...@@ -58,7 +46,4 @@ class PreprocessedRequest(BaseModel): ...@@ -58,7 +46,4 @@ class PreprocessedRequest(BaseModel):
class DisaggPreprocessedRequest(BaseModel): class DisaggPreprocessedRequest(BaseModel):
request: PreprocessedRequest request: PreprocessedRequest
sampling_params: dict sampling_params: dict
bootstrap_host: Union[str, List[str]]
bootstrap_port: Union[int, List[int]]
bootstrap_room: Union[int, List[int]]
data_parallel_rank: Optional[int] = None data_parallel_rank: Optional[int] = None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from typing import Optional
import sglang as sgl
import zmq
import zmq.asyncio
from sglang.srt.utils import get_zmq_socket
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
SpecDecodeStats,
WorkerMetricsPublisher,
WorkerStats,
)
from dynamo.runtime import Component
class DynamoSglangStatPublisher:
    """
    Receives SGLang scheduler metrics over a ZMQ PULL socket and republishes
    them through a ``WorkerMetricsPublisher``.
    """

    def __init__(self, engine: sgl.Engine, component: Component) -> None:
        self.engine = engine

        self.inner = WorkerMetricsPublisher()
        self.inner.create_endpoint(component)

        # Placeholder values used for the initial publish (and as the fallback
        # data-parallel rank); can be overridden later if needed.
        self.request_total_slots = 1024
        self.dp_rank = 0
        self.num_gpu_block = 1024

        # ZMQ setup for receiving scheduler metrics.
        # NOTE(review): the trailing True is presumably get_zmq_socket's bind flag — confirm.
        self._ctx = zmq.asyncio.Context()  # type: ignore
        self._sock = get_zmq_socket(
            self._ctx, zmq.PULL, self.engine.port_args.metrics_ipc_name, True  # type: ignore
        )

    async def run(self) -> None:
        """
        Receive scheduler metrics forever and publish each one.

        Any exception raised while receiving or publishing is logged with its
        traceback and the loop continues.
        """
        forwarded_fields = (
            "request_active_slots",
            "request_total_slots",
            "kv_active_blocks",
            "kv_total_blocks",
            "num_requests_waiting",
            "gpu_cache_usage_perc",
            "gpu_prefix_cache_hit_rate",
            "data_parallel_rank",
        )
        while True:
            try:
                kv_metrics = await self._sock.recv_pyobj()  # type: ignore
                self.record_values(
                    **{name: getattr(kv_metrics, name) for name in forwarded_fields}
                )
            except Exception:
                logging.exception(
                    "Failed to receive or publish SGLang scheduler metrics"
                )

    def init_publish(self) -> None:
        """Publish one zeroed metrics record built from the placeholder totals."""
        metrics = ForwardPassMetrics(
            worker_stats=WorkerStats(
                request_active_slots=0,
                request_total_slots=self.request_total_slots,
                num_requests_waiting=0,
                data_parallel_rank=self.dp_rank,
            ),
            kv_stats=KvStats(
                kv_active_blocks=0,
                kv_total_blocks=self.num_gpu_block,
                gpu_cache_usage_perc=0.0,
                gpu_prefix_cache_hit_rate=0.0,
            ),
            spec_decode_stats=None,
        )
        logging.info("Sending dummy metrics to initialize")
        self.inner.publish(metrics)

    def record(
        self,
        worker_stats: WorkerStats,
        kv_stats: KvStats,
        spec_decode_stats: Optional[SpecDecodeStats] = None,
    ) -> None:
        """Wrap the given stats in a ``ForwardPassMetrics`` and publish it."""
        self.inner.publish(
            ForwardPassMetrics(
                worker_stats=worker_stats,
                kv_stats=kv_stats,
                spec_decode_stats=spec_decode_stats,
            )
        )

    def record_values(
        self,
        request_active_slots: int,
        request_total_slots: int,
        kv_active_blocks: int,
        kv_total_blocks: int,
        num_requests_waiting: int,
        gpu_cache_usage_perc: float,
        gpu_prefix_cache_hit_rate: float,
        data_parallel_rank: Optional[int] = None,
        spec_decode_stats: Optional[SpecDecodeStats] = None,
    ) -> None:
        """
        Build the stats objects from plain values and publish them.

        When ``data_parallel_rank`` is None, ``self.dp_rank`` is used instead.
        """
        rank = self.dp_rank if data_parallel_rank is None else data_parallel_rank
        self.record(
            WorkerStats(
                request_active_slots=request_active_slots,
                request_total_slots=request_total_slots,
                num_requests_waiting=num_requests_waiting,
                data_parallel_rank=rank,
            ),
            KvStats(
                kv_active_blocks=kv_active_blocks,
                kv_total_blocks=kv_total_blocks,
                gpu_cache_usage_perc=gpu_cache_usage_perc,
                gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
            ),
            spec_decode_stats,
        )
async def setup_sgl_metrics(
    engine: sgl.Engine,
    component: Component,
) -> tuple[DynamoSglangStatPublisher, asyncio.Task]:
    """
    Create the stat publisher, send its initial record and start its receive loop.

    Returns:
        The publisher together with the asyncio task running ``publisher.run()``;
        the caller owns the task and is expected to cancel it.
    """
    stat_publisher = DynamoSglangStatPublisher(engine, component)
    stat_publisher.init_publish()

    metrics_loop = asyncio.create_task(stat_publisher.run())
    logging.info("SGLang metrics loop started")

    return stat_publisher, metrics_loop
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Optional
import sglang as sgl
from sglang.srt.server_args import ServerArgs
from dynamo._core import Endpoint
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm
async def register_llm_with_runtime_config(
    engine: sgl.Engine,
    endpoint: Endpoint,
    server_args: ServerArgs,
    migration_limit: int,
):
    """
    Register the model served by ``endpoint`` as a backend, attaching runtime config.

    The runtime config is looked up from the engine first (it may be None).
    A registration failure is logged and swallowed; the function always returns None.
    """
    runtime_config = await _get_runtime_config(engine)

    try:
        # Keyword options are assembled inside the try so that a failure while
        # reading server_args is handled the same way as a registration failure.
        options = {
            "kv_cache_block_size": server_args.page_size,
            "migration_limit": migration_limit,
            "runtime_config": runtime_config,
        }
        await register_llm(
            ModelType.Backend,
            endpoint,
            server_args.model_path,
            server_args.served_model_name,
            **options,
        )
    except Exception as e:
        logging.error(f"Failed to register with runtime config: {e}")
    return None
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
    """
    Derive a ``ModelRuntimeConfig`` from the engine's ``scheduler_info``.

    Returns None (after a warning) when the engine has no usable
    ``scheduler_info`` or when anything raises along the way. Otherwise returns
    a config whose ``total_kv_blocks`` is set only if both
    ``max_total_num_tokens`` and the tokenizer manager's page size are truthy.
    """
    try:
        scheduler_info = getattr(engine, "scheduler_info", None)
        if scheduler_info is None:
            # Scheduler approach is unavailable: skip runtime config entirely.
            logging.warning(
                "Could not access runtime config from SGLang engine. "
                "The engine may compute these values internally after initialization. "
                "Proceeding without runtime config - SGLang will use its internal defaults."
            )
            return None

        runtime_config = ModelRuntimeConfig()

        max_total_tokens = None
        if "max_total_num_tokens" in scheduler_info:
            max_total_tokens = scheduler_info["max_total_num_tokens"]

        if max_total_tokens and hasattr(engine.tokenizer_manager, "server_args"):
            page_size = engine.tokenizer_manager.server_args.page_size
            if page_size:
                # Number of pages needed to hold max_total_tokens, rounded up.
                runtime_config.total_kv_blocks = (
                    max_total_tokens + page_size - 1
                ) // page_size
                logging.info(
                    f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
                    f"(max_total_tokens={max_total_tokens}, page_size={page_size})"
                )

        # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
        return runtime_config
    except Exception as e:
        logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.")
        return None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .decode_handler import DecodeWorkerHandler
# Base handlers
from .handler_base import BaseWorkerHandler
from .prefill_handler import PrefillWorkerHandler
# Public API of this package: the three handler classes imported above from
# .handler_base, .decode_handler and .prefill_handler.
__all__ = [
"BaseWorkerHandler",
"DecodeWorkerHandler",
"PrefillWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import sglang as sgl
from dynamo._core import Client, Component
from dynamo.llm import WorkerMetricsPublisher, ZmqKvEventPublisher
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import DisaggPreprocessedRequest
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
class DecodeWorkerHandler(BaseWorkerHandler):
    """
    Handler that runs generation on the local engine.

    In DECODE serving mode it first asks a prefill worker for bootstrap info and
    passes that to the engine; in any other mode it calls the engine directly.
    """

    def __init__(
        self,
        component: Component,
        engine: sgl.Engine,
        config: Config,
        metrics_publisher: WorkerMetricsPublisher,
        kv_publisher: ZmqKvEventPublisher = None,
        prefill_client: Client = None,
    ):
        """
        Raises:
            ValueError: if the serving mode is DECODE and no ``prefill_client`` was given.
        """
        # The base class stores every argument, including prefill_client.
        super().__init__(
            component, engine, config, metrics_publisher, kv_publisher, prefill_client
        )
        if self.serving_mode == DisaggregationMode.DECODE:
            if self.prefill_client is None:
                raise ValueError(
                    "prefill_client must be provided when serving_mode is decode"
                )
            logging.info("Decode worker handler initialized")

        logging.info("Worker handler initialized")

    def cleanup(self):
        """Shut the engine down, then run the base class cleanup."""
        self.engine.shutdown()
        logging.info("Engine shutdown")
        super().cleanup()

    def _build_sampling_params(self, request: dict) -> dict:
        """
        Translate the request's sampling options and stop conditions into the
        sampling-parameter dict handed to the engine.

        ``max_new_tokens`` is always set (possibly to None); the other keys are
        included only when the request carries a value for them.
        """
        sampling_options = request["sampling_options"]
        stop_conditions = request["stop_conditions"]

        sampling_params = {}
        # Compare against None rather than truthiness: a temperature of 0 is an
        # explicit request for greedy decoding and must not be dropped.
        if sampling_options["temperature"] is not None:
            sampling_params["temperature"] = sampling_options["temperature"]
        # NOTE(review): top_p / top_k keep the truthiness test, so 0 is still
        # skipped for them — confirm that is the intended handling.
        if sampling_options["top_p"]:
            sampling_params["top_p"] = sampling_options["top_p"]
        if sampling_options["top_k"]:
            sampling_params["top_k"] = sampling_options["top_k"]
        sampling_params["max_new_tokens"] = stop_conditions["max_tokens"]
        if stop_conditions["ignore_eos"]:
            sampling_params["ignore_eos"] = stop_conditions["ignore_eos"]
        return sampling_params

    async def generate(self, request: dict):
        """
        Stream generated token chunks for ``request``.

        Args:
            request: Mapping with at least ``token_ids``, ``sampling_options``
                and ``stop_conditions`` (it is indexed, not JSON-decoded, here).

        Yields:
            Dicts produced by ``_process_stream``.

        Raises:
            RuntimeError: in DECODE mode, if the prefill worker's stream yields
                no bootstrap info.
        """
        sampling_params = self._build_sampling_params(request)

        if self.serving_mode == DisaggregationMode.DECODE:
            # request the bootstrap info from the target prefill worker
            prefill_stream = await self.prefill_client.generate(
                DisaggPreprocessedRequest(
                    request=request,
                    sampling_params=sampling_params,
                ).model_dump_json()
            )

            # Only the first item of the prefill stream is consumed.
            bootstrap_info = None
            async for info in prefill_stream:
                bootstrap_info = info.data()
                break

            if not bootstrap_info:
                raise RuntimeError("No bootstrap info received from prefill worker")

            decode = await self.engine.async_generate(
                input_ids=request["token_ids"],
                sampling_params=sampling_params,
                stream=True,
                bootstrap_host=bootstrap_info["bootstrap_host"],
                bootstrap_port=bootstrap_info["bootstrap_port"],
                bootstrap_room=bootstrap_info["bootstrap_room"],
            )

            async for out in self._process_stream(decode):
                yield out
        else:
            agg = await self.engine.async_generate(
                input_ids=request["token_ids"],
                sampling_params=sampling_params,
                stream=True,
            )
            async for out in self._process_stream(agg):
                yield out

    async def _process_stream(self, stream_source):
        """
        Convert engine results into incremental token chunks.

        Each result's ``output_ids`` is treated as cumulative: only the ids past
        the count already emitted are yielded. A result with a finish reason
        yields an empty ``token_ids`` list plus ``finish_reason``.
        """
        num_output_tokens_so_far = 0

        async for res in stream_source:
            finish_reason = res["meta_info"]["finish_reason"]
            if finish_reason:
                # NOTE(review): any new ids carried by the finishing result are
                # not emitted — confirm the engine never adds tokens there.
                out = {"token_ids": [], "finish_reason": finish_reason["type"]}
            else:
                next_total_toks = len(res["output_ids"])
                out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
                num_output_tokens_so_far = next_total_toks

            yield out
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
import sglang as sgl
from dynamo._core import Client, Component
from dynamo.llm import WorkerMetricsPublisher, ZmqKvEventPublisher
from dynamo.sglang.args import Config
class BaseWorkerHandler(ABC):
    """
    Common state holder for sglang worker handlers.

    Stores the collaborators passed to the constructor and exposes
    ``serving_mode`` copied from ``config``. Subclasses must implement
    ``generate``; ``cleanup`` is a no-op hook they may extend.
    """

    def __init__(
        self,
        component: Component,
        engine: sgl.Engine,
        config: Config,
        metrics_publisher: WorkerMetricsPublisher = None,
        kv_publisher: ZmqKvEventPublisher = None,
        prefill_client: Client = None,
    ):
        self.component, self.engine, self.config = component, engine, config
        self.metrics_publisher, self.kv_publisher = metrics_publisher, kv_publisher
        self.prefill_client = prefill_client
        self.serving_mode = config.serving_mode

    @abstractmethod
    async def generate(self, request: str):
        """Handle one request; concrete handlers define what is produced."""
        pass

    def cleanup(self):
        """Release handler resources; the base implementation does nothing."""
        pass
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import random
import socket
import msgspec
import sglang as sgl
from sglang.srt.utils import get_ip
from dynamo._core import Component
from dynamo.sglang.args import Config
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
class PrefillWorkerHandler(BaseWorkerHandler):
    """
    Handler for the prefill side of disaggregated serving.

    For each request it yields the bootstrap host/port/room once, then starts the
    engine generation with that bootstrap info and drains the results in a
    background task.
    """

    def __init__(self, component: Component, engine: sgl.Engine, config: Config):
        # _get_bootstrap_info reads self.engine, so it is set before the base
        # class constructor runs.
        self.engine = engine
        self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info()
        # Strong references to in-flight drain tasks. The event loop holds tasks
        # only weakly, so an unreferenced task could be collected mid-execution.
        self._consume_tasks = set()
        super().__init__(component, engine, config, None, None, None)
        logging.info(
            f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
        )

    def _generate_bootstrap_room(self):
        """Return a random integer in [0, 2**63 - 1] identifying one request."""
        return random.randint(0, 2**63 - 1)

    def cleanup(self):
        """Shut the engine down, then run the base class cleanup."""
        self.engine.shutdown()
        logging.info("Prefill engine shutdown")
        super().cleanup()

    def _get_bootstrap_info(self):
        """
        Bootstrap info from tokenizer manager.

        The port is the tokenizer manager's ``disaggregation_bootstrap_port``.
        The host is the resolved host part of ``dist_init_addr`` when that is
        set, otherwise the result of ``get_ip()``.
        """
        inner_tm = self.engine.tokenizer_manager
        bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port

        if inner_tm.server_args.dist_init_addr:
            bootstrap_host = socket.gethostbyname(
                inner_tm.server_args.dist_init_addr.split(":")[0]
            )
        else:
            bootstrap_host = get_ip()

        return bootstrap_host, bootstrap_port

    async def generate(self, request: str):
        """
        Yield the bootstrap info for a JSON-encoded request, then start prefill.

        Args:
            request: JSON string whose decoded form carries ``request.token_ids``
                and ``sampling_params``.

        Yields:
            One dict with ``bootstrap_host``, ``bootstrap_port`` and
            ``bootstrap_room``. Engine results are not yielded; they are drained
            by a background task.
        """
        req = msgspec.json.decode(request, type=dict)
        bootstrap_room = self._generate_bootstrap_room()

        bootstrap_info = {
            "bootstrap_host": self.bootstrap_host,
            "bootstrap_port": self.bootstrap_port,
            "bootstrap_room": bootstrap_room,
        }

        yield bootstrap_info

        results = await self.engine.async_generate(
            input_ids=req["request"]["token_ids"],
            sampling_params=req["sampling_params"],
            stream=True,
            bootstrap_host=self.bootstrap_host,
            bootstrap_port=self.bootstrap_port,
            bootstrap_room=bootstrap_room,
        )

        # Keep the task referenced until it completes, then drop it.
        task = asyncio.create_task(self._consume_results(results))
        self._consume_tasks.add(task)
        task.add_done_callback(self._consume_tasks.discard)

    async def _consume_results(self, results):
        """Exhaust the engine's result stream, discarding every item."""
        async for _ in results:
            pass
# SPDX-FileCopyrightText: Copyright (c) 2020 Atalaya Tech. Inc
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import argparse import argparse
import asyncio import asyncio
...@@ -23,7 +9,6 @@ from dynamo.runtime import DistributedRuntime, EtcdKvCache, dynamo_worker ...@@ -23,7 +9,6 @@ from dynamo.runtime import DistributedRuntime, EtcdKvCache, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__)
@dynamo_worker() @dynamo_worker()
...@@ -34,7 +19,7 @@ async def clear_namespace(runtime: DistributedRuntime, namespace: str): ...@@ -34,7 +19,7 @@ async def clear_namespace(runtime: DistributedRuntime, namespace: str):
{}, {},
) )
await etcd_kv_cache.clear_all() await etcd_kv_cache.clear_all()
logger.info(f"Cleared /{namespace} in EtcdKvCache") logging.info(f"Cleared /{namespace} in EtcdKvCache")
if __name__ == "__main__": if __name__ == "__main__":
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import uvicorn
import uvloop
from fastapi import FastAPI
from fastapi.routing import APIRoute
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
FLUSH_CACHE_ENDPOINT = "flush_cache"
configure_dynamo_logging()
class SglangHttpServer:
def __init__(self, port: int, runtime: DistributedRuntime, args):
self.port = port
self.app = FastAPI()
self.runtime = runtime
self.args = args
self.setup_routes()
async def _discover_endpoints(self, endpoint_name):
"""Discover endpoints that match the pattern"""
etcd_client = self.runtime.etcd_client()
if etcd_client is None:
raise RuntimeError("Runtime has no etcd client; cannot discover endpoints")
prefix = "instances/"
kvs = await etcd_client.kv_get_prefix(prefix)
# Collect (namespace, component) combos that expose the target endpoint
discovered = set()
for kv in kvs:
key = kv["key"] if isinstance(kv, dict) else kv.key
if isinstance(key, bytes):
key = key.decode()
if not key.startswith(prefix):
continue
segments = key.split("/")
# Format: instances/<ns>/<comp>/<endpoint:lease>
if len(segments) < 4:
continue
ns, comp, ep_with_lease = segments[1], segments[2], segments[3]
if self.args.ns and ns != self.args.ns:
continue
if self.args.comp and comp != self.args.comp:
continue
ep_name = ep_with_lease.split(":", 1)[0]
if ep_name == endpoint_name:
discovered.add((ns, comp))
logging.debug(f"Discovered endpoint: {ns}.{comp}")
logging.debug(
f"Endpoint discovery complete. Found {len(discovered)} matching endpoints"
)
return discovered
async def _dispatch_command(
self, endpoint_name: str, payload: dict | str = "{}", success_message: str = ""
):
"""Dispatches a command to all instances of a discovered endpoint."""
discovered = await self._discover_endpoints(endpoint_name=endpoint_name)
if not discovered:
return {"message": "No matching endpoints found", "success": False}
logging.debug(
f"Found components: {', '.join([f'{ns}.{comp}' for ns, comp in discovered])}"
)
for ns, comp in discovered:
ep = self.runtime.namespace(ns).component(comp).endpoint(endpoint_name)
client = await ep.client()
await client.wait_for_instances()
ids = client.instance_ids()
logging.debug(f"-- {ns}.{comp} : {len(ids)} instances --")
for inst_id in ids:
try:
stream = await client.direct(payload, inst_id)
async for stream_payload in stream:
logging.debug(f"[{ns}.{comp}][{inst_id}] -> {stream_payload}")
except Exception as e:
logging.error(
f"[{ns}.{comp}][{inst_id}] {endpoint_name} error: {e}"
)
return {"message": success_message, "success": True}
def setup_routes(self):
    """Register the POST routes that forward commands to the workers.

    Each route dispatches to the worker endpoint of the same name through
    ``self._dispatch_command`` and turns any exception into a
    ``{"message": ..., "success": False}`` response.
    """

    async def _run(endpoint_name: str, success_message: str, label: str):
        # Shared wrapper: every route reports failures in the same shape.
        try:
            return await self._dispatch_command(
                endpoint_name,
                success_message=success_message,
            )
        except Exception as e:
            logging.error(f"{label} error: {e}")
            return {"message": f"{label} failed: {str(e)}", "success": False}

    @self.app.post("/flush_cache")
    async def flush_cache():
        """Flush the radix cache."""
        # The parsed CLI namespace may not carry an ``endpoint`` attribute;
        # fall back to the endpoint that matches this route's name instead
        # of raising AttributeError outside any error handling.
        endpoint_name = getattr(self.args, "endpoint", None) or "flush_cache"
        return await _run(endpoint_name, "Cache flush initiated", "Cache flush")

    @self.app.post("/start_expert_distribution_record")
    async def start_expert_distribution_record():
        """Start recording expert distribution."""
        return await _run(
            "start_expert_distribution_record",
            "Expert distribution recording started",
            "Start expert distribution",
        )

    @self.app.post("/stop_expert_distribution_record")
    async def stop_expert_distribution_record():
        """Stop recording expert distribution."""
        return await _run(
            "stop_expert_distribution_record",
            "Expert distribution recording stopped",
            "Stop expert distribution",
        )

    @self.app.post("/dump_expert_distribution_record")
    async def dump_expert_distribution_record(request: dict):
        """Dump expert distribution recording to specified directory."""
        # NOTE(review): the request body is accepted but not forwarded to the
        # workers; confirm whether the dump directory should be passed on.
        return await _run(
            "dump_expert_distribution_record",
            "Expert distribution recording dumped to directory",
            "Dump expert distribution",
        )
async def start_server(self):
    """Build the uvicorn server, log the registered routes, then serve."""
    server = uvicorn.Server(
        uvicorn.Config(self.app, host="0.0.0.0", port=self.port)
    )

    # Debug aid: list every API route that was registered on the app.
    api_routes = (r for r in self.app.routes if isinstance(r, APIRoute))
    for api_route in api_routes:
        logging.debug(f"Registered route: {api_route.methods} {api_route.path}")

    await server.serve()
def parse_args(argv=None):
    """Parse command-line options for the SGLang HTTP server.

    Args:
        argv: Optional list of argument strings. When None (the default)
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace`` with ``port``, ``ns``, ``comp`` and
        ``endpoint`` attributes.
    """
    p = argparse.ArgumentParser(description="SGLang HTTP server for cache management")
    p.add_argument("--port", type=int, default=9001, help="Port to listen on")
    p.add_argument(
        "--ns",
        "--namespace",
        default="dynamo",
        help="Specify Dynamo namespace (default: dynamo)",
    )
    p.add_argument(
        "--comp",
        "--component",
        default=None,
        help="Specify component name (default: discover all)",
    )
    # The /flush_cache route reads ``args.endpoint``; define it so the
    # attribute always exists.
    p.add_argument(
        "--endpoint",
        default="flush_cache",
        help="Worker endpoint called by /flush_cache (default: flush_cache)",
    )
    return p.parse_args(argv)
@dynamo_worker(static=False)
async def main(runtime: DistributedRuntime):
    """Parse the CLI options and run the HTTP server on the given runtime."""
    cli_args = parse_args()
    await SglangHttpServer(cli_args.port, runtime, cli_args).start_server()
if __name__ == "__main__":
    # Install uvloop first so asyncio.run() below picks it up.
    uvloop.install()
    server_entrypoint = main()
    asyncio.run(server_entrypoint)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# This module is deprecated. Use `python3 -m dynamo.sglang` instead.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.sglang.worker.main import main
import logging
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.main import main
if __name__ == "__main__":
    configure_dynamo_logging()
    # Adjacent string literals are concatenated as-is, so the first part
    # must end with a space to keep the two sentences apart.
    logging.warning(
        "DEPRECATION WARNING: `python3 -m dynamo.sglang.worker` is deprecated and will be removed in dynamo v0.5.0. "
        "Use `python3 -m dynamo.sglang` instead.",
    )
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
import random
import signal
import socket
import sys
from typing import Any, Dict, Optional, Union
import sglang as sgl
import uvloop
import zmq
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_ip, get_zmq_socket
from dynamo._core import Endpoint
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelRuntimeConfig,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
ZmqKvEventPublisher,
ZmqKvEventPublisherConfig,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.common import (
BaseWorkerHandler,
DisaggPreprocessedRequest,
graceful_shutdown,
parse_sglang_args_inc,
setup_native_endpoints,
)
configure_dynamo_logging()
class RequestHandler(BaseWorkerHandler):
def __init__(
    self,
    engine: sgl.Engine,
    server_args: ServerArgs,
    component,
    decode_client: Optional[Any] = None,
):
    """Create the request handler for a prefill or aggregated worker.

    Args:
        engine: SGLang engine used for generation.
        server_args: SGLang server arguments; ``disaggregation_mode``
            selects between aggregated ("null") and disaggregated serving.
        component: Runtime component the handler belongs to.
        decode_client: Client for the decode worker; required whenever
            ``disaggregation_mode`` is not "null".

    Raises:
        ValueError: If disaggregation is enabled without a decode client.
    """
    disaggregated = server_args.disaggregation_mode != "null"
    # Fail fast: reject bad arguments before creating the publisher and
    # ZMQ context or resolving the bootstrap host.
    if disaggregated and decode_client is None:
        raise ValueError(
            "decode_client must be provided when disaggregation_mode is not 'null'"
        )

    super().__init__(engine, server_args, component, decode_client)
    self.metrics_publisher = WorkerMetricsPublisher()
    self.zmq_context = zmq.asyncio.Context()  # type: ignore
    self.receive_metrics_from_scheduler = None
    # Always defined, so the attribute exists in aggregated mode too.
    self.decode_client = decode_client

    if disaggregated:
        self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info()
        logging.info(
            f"Disaggregation enabled - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
        )

    logging.info("Request handler initialized")
def setup_metrics(self):
    """Set up metrics publisher

    Opens the PULL socket on the engine's metrics IPC name, publishes an
    initial set of warmup metrics, then starts the receive/publish loop
    and the endpoint-creation coroutine as background tasks. Must be
    called while an event loop is running.
    """
    self.receive_metrics_from_scheduler = get_zmq_socket(
        self.zmq_context, zmq.PULL, self.engine.port_args.metrics_ipc_name, True
    )
    self.init_publish()

    loop_task = asyncio.create_task(self._receive_and_publish_metrics_loop())
    endpoint_task = asyncio.create_task(self.create_metrics_publisher_endpoint())
    # The event loop only keeps weak references to tasks; hold strong ones
    # so the background work cannot be garbage-collected mid-flight.
    self._metrics_tasks = [loop_task, endpoint_task]

    def _on_endpoint_done(done_task):
        # Report the real outcome instead of always claiming success.
        if done_task.cancelled():
            logging.debug("metrics publisher endpoint task cancelled")
            return
        error = done_task.exception()
        if error is not None:
            logging.error(f"metrics publisher endpoint failed: {error}")
        else:
            logging.debug("metrics publisher endpoint created")

    endpoint_task.add_done_callback(_on_endpoint_done)
def init_publish(self):
    """Publish one placeholder metrics sample before real metrics arrive."""
    warmup_metrics = ForwardPassMetrics(
        worker_stats=WorkerStats(
            request_active_slots=0,
            request_total_slots=1024,
            num_requests_waiting=0,
            data_parallel_rank=0,
        ),
        kv_stats=KvStats(
            kv_active_blocks=0,
            kv_total_blocks=1024,
            gpu_cache_usage_perc=0,
            gpu_prefix_cache_hit_rate=0,
        ),
        spec_decode_stats=None,
    )
    self.metrics_publisher.publish(warmup_metrics)
async def create_metrics_publisher_endpoint(self):
    """Create the metrics publisher's endpoint on this handler's component."""
    logging.debug("Creating metrics publisher endpoint")
    publisher = self.metrics_publisher
    await publisher.create_endpoint(self.component)
async def _receive_and_publish_metrics_loop(self):
    """Receive metrics from SGL scheduler and publish them

    Runs until the task is cancelled. Any failure while receiving or
    publishing is logged and followed by a short pause, so a failure that
    raises before awaiting (for example a missing socket) cannot turn
    into a busy loop that never yields to the event loop.
    """
    retry_delay_s = 1.0
    while True:
        try:
            # NOTE: recv_pyobj unpickles the message; the socket must only
            # ever be fed by the trusted local scheduler.
            kv_metrics = await self.receive_metrics_from_scheduler.recv_pyobj()  # type: ignore
            worker_stats = WorkerStats(
                request_active_slots=kv_metrics.request_active_slots,
                request_total_slots=kv_metrics.request_total_slots,
                num_requests_waiting=kv_metrics.num_requests_waiting,
                data_parallel_rank=kv_metrics.data_parallel_rank,  # Note: 0 means it's either 0 or None from sglang
            )
            kv_stats = KvStats(
                kv_active_blocks=kv_metrics.kv_active_blocks,
                kv_total_blocks=kv_metrics.kv_total_blocks,
                gpu_cache_usage_perc=kv_metrics.gpu_cache_usage_perc,
                gpu_prefix_cache_hit_rate=kv_metrics.gpu_prefix_cache_hit_rate,
            )
            metrics = ForwardPassMetrics(
                worker_stats=worker_stats,
                kv_stats=kv_stats,
                spec_decode_stats=None,
            )
            self.metrics_publisher.publish(metrics)
        except Exception:
            # asyncio.CancelledError derives from BaseException, so task
            # cancellation still ends the loop.
            logging.exception("Failed to receive or publish metrics")
            await asyncio.sleep(retry_delay_s)
def _get_bootstrap_info(self):
    """Return ``(bootstrap_host, bootstrap_port)`` from the tokenizer manager.

    The port is the tokenizer manager's ``disaggregation_bootstrap_port``.
    The host is the resolved host part of ``dist_init_addr`` when that is
    set, otherwise the address returned by ``get_ip()``.
    """
    tm_args = self.engine.tokenizer_manager.server_args
    port = tm_args.disaggregation_bootstrap_port
    init_addr = tm_args.dist_init_addr
    if not init_addr:
        return get_ip(), port
    # Keep only the text before the first ':' and resolve it.
    host_part = init_addr.partition(":")[0]
    return socket.gethostbyname(host_part), port
def _build_sampling_params(self, request: dict) -> dict:
    """Translate a preprocessed request into SGLang sampling parameters.

    Args:
        request: Dict with a ``sampling_options`` mapping (``temperature``,
            ``top_p``, ``top_k``) and a ``stop_conditions`` mapping
            (``max_tokens``, ``ignore_eos``).

    Returns:
        Dict holding only the options that were set, plus
        ``max_new_tokens`` which is always present.
    """
    sampling_options = request["sampling_options"]
    stop_conditions = request["stop_conditions"]
    sampling_params = {}

    # ``temperature=0`` is a legitimate request for greedy decoding, so only
    # an unset (None) temperature is skipped; a truthiness test would drop 0.
    temperature = sampling_options["temperature"]
    if temperature is not None:
        sampling_params["temperature"] = temperature

    # For top_p / top_k a falsy value keeps being treated as "not set".
    for option in ("top_p", "top_k"):
        value = sampling_options[option]
        if value:
            sampling_params[option] = value

    sampling_params["max_new_tokens"] = stop_conditions["max_tokens"]

    ignore_eos = stop_conditions["ignore_eos"]
    if ignore_eos:
        sampling_params["ignore_eos"] = ignore_eos
    return sampling_params
def _get_request_batch_size(self, request: dict):
    """Get batch size from request, returns None for single requests

    A request whose ``batch_token_ids`` key is missing is treated the same
    as one where it is None, instead of raising KeyError.
    """
    batch_token_ids = request.get("batch_token_ids")
    if batch_token_ids is None:
        return None
    return len(batch_token_ids)
def _is_batch_request(self, request: dict):
    """Check if request is in batch mode

    A request is in batch mode when it carries a non-None
    ``batch_token_ids`` entry; a missing key counts as a single request
    instead of raising KeyError.
    """
    return request.get("batch_token_ids") is not None
def _generate_bootstrap_room(self):
    """Return a random bootstrap room id in the range [0, 2**63 - 1]."""
    max_room_id = (1 << 63) - 1
    return random.randint(0, max_room_id)
async def generate(self, request: dict):
    """Stream generation output for a single or batched request.

    In aggregated mode ("null" disaggregation) the local engine's stream
    is processed directly. Otherwise the local engine runs with bootstrap
    information attached, its output is drained in a background task, and
    the stream returned by the decode client is what gets yielded.
    """
    is_batch = self._is_batch_request(request)
    batch_size = self._get_request_batch_size(request)

    # TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
    sampling_params = self._build_sampling_params(request)

    if self.server_args.disaggregation_mode == "null":
        engine_stream = await self.engine.async_generate(
            input_ids=request["batch_token_ids"] if is_batch else request["token_ids"],
            sampling_params=sampling_params,
            stream=True,
        )
        async for out in self._process_stream(
            engine_stream, unpack=False, is_batch=is_batch
        ):
            yield out
        return

    # Disaggregated path: one bootstrap triple per batch entry, or scalars
    # for a single request.
    if is_batch:
        bootstrap_room = [
            self._generate_bootstrap_room() for _ in range(batch_size)
        ]
        bootstrap_host = [self.bootstrap_host] * batch_size
        bootstrap_port = [self.bootstrap_port] * batch_size
    else:
        bootstrap_host = self.bootstrap_host
        bootstrap_port = self.bootstrap_port
        bootstrap_room = self._generate_bootstrap_room()

    # Request forwarded to the decode worker.
    disagg_request = DisaggPreprocessedRequest(
        request=request,
        sampling_params=sampling_params,
        bootstrap_host=bootstrap_host,
        bootstrap_port=bootstrap_port,
        bootstrap_room=bootstrap_room,
    )

    # The prefill output itself is not used; it is only drained.
    prefill = await self.engine.async_generate(
        input_ids=request["batch_token_ids"] if is_batch else request["token_ids"],
        sampling_params=sampling_params,
        stream=True,
        bootstrap_host=bootstrap_host,
        bootstrap_port=bootstrap_port,
        bootstrap_room=bootstrap_room,
    )
    prefill_task = asyncio.create_task(self._prefill_generator(prefill))

    decode = await self.decode_client.generate(disagg_request.model_dump_json())

    async for out in self._process_stream(decode, unpack=True, is_batch=is_batch):
        yield out

    await prefill_task
async def _process_stream(self, stream_source, unpack: bool, is_batch: bool):
    """Turn cumulative engine output into incremental token chunks.

    Each item of ``stream_source`` (or its ``.data()`` when ``unpack`` is
    true) carries the full ``output_ids`` so far. Only the ids not yet
    emitted are yielded. A chunk with a finish reason yields an empty
    token list plus ``finish_reason``. In batch mode progress is tracked
    per ``index`` (default 0) and the index is included in every chunk.
    """
    sent_by_index: Dict[int, int] = {}  # batch mode: tokens emitted per index
    sent_single = 0  # single mode: tokens emitted so far

    async for item in stream_source:
        payload = item.data() if unpack else item
        finish_reason = payload["meta_info"]["finish_reason"]

        if is_batch:
            index = payload.get("index", 0)
            already_sent = sent_by_index.setdefault(index, 0)
            if finish_reason:
                yield {
                    "token_ids": [],
                    "finish_reason": finish_reason["type"],
                    "index": index,
                }
                continue
            output_ids = payload["output_ids"]
            sent_by_index[index] = len(output_ids)
            yield {"token_ids": output_ids[already_sent:], "index": index}
            continue

        if finish_reason:
            yield {"token_ids": [], "finish_reason": finish_reason["type"]}
            continue
        output_ids = payload["output_ids"]
        fresh_ids = output_ids[sent_single:]
        sent_single = len(output_ids)
        yield {"token_ids": fresh_ids}
async def _prefill_generator(self, prefill):
    """Drain the prefill stream to completion, discarding every item."""
    async for _discarded in prefill:
        continue
async def flush_cache(self, request: dict):
    """Start a KV cache flush in the background and acknowledge at once.

    The flush runs as a background task; the yielded dict only confirms
    that it was started. A failure of the flush is logged when the task
    finishes.
    """
    _ = request
    flush_task = asyncio.create_task(self.engine.tokenizer_manager.flush_cache())

    # The event loop only keeps weak references to tasks, so keep a strong
    # one until the flush completes.
    pending = getattr(self, "_flush_tasks", None)
    if pending is None:
        pending = self._flush_tasks = set()
    pending.add(flush_task)

    def _on_flush_done(done_task):
        pending.discard(done_task)
        if done_task.cancelled():
            return
        error = done_task.exception()
        if error is not None:
            logging.error(f"Cache flush failed: {error}")

    flush_task.add_done_callback(_on_flush_done)

    yield {
        "status": "success",
        "message": "Cache flush initiated. Check backend logs for status",
    }
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Worker entry point: install signal handlers, parse args, start serving.

    ``--migration-limit <int>`` is consumed here and removed from the
    argument list; all remaining arguments go to the SGLang parser.
    """
    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    def signal_handler():
        # Schedule the shutdown coroutine instead of calling it directly
        asyncio.create_task(graceful_shutdown(runtime))

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logging.info("Signal handlers set up for graceful shutdown")

    # TODO: Better handle non-sglang args
    sys_argv = sys.argv[1:]
    migration_limit = 0
    flag = "--migration-limit"
    if flag in sys_argv:
        idx = sys_argv.index(flag)
        try:
            migration_limit = int(sys_argv[idx + 1])
        except (IndexError, ValueError):
            # Same outcome as before (default limit, arguments untouched),
            # but the problem is now visible instead of silently ignored.
            logging.warning(
                f"Ignoring {flag}: expected an integer value after the flag"
            )
        else:
            del sys_argv[idx : idx + 2]  # Remove the args from sys_argv

    server_args = parse_sglang_args_inc(sys_argv)
    await init(runtime, server_args, migration_limit)
async def init(
    runtime: DistributedRuntime, server_args: ServerArgs, migration_limit: int
):
    """Initialize worker (either prefill or aggregated)

    Creates the SGLang engine, registers the model on the ``generate``
    endpoint of the ``dynamo``/``worker`` component, builds the request
    handler (with a decode client when disaggregation is enabled), starts
    metrics, optionally creates the ZMQ KV event publisher, and then serves
    the endpoints until they finish.

    Args:
        runtime: Distributed runtime used to create components and clients.
        server_args: Parsed SGLang server arguments.
        migration_limit: Value forwarded to the model registration.
    """
    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace("dynamo").component("worker")
    await component.create_service()

    endpoint = component.endpoint("generate")
    await register_llm_with_runtime_config(
        engine, endpoint, server_args, migration_limit
    )

    if server_args.disaggregation_mode != "null":
        # Disaggregated mode: the handler needs a client for the
        # ``generate`` endpoint of the ``decode`` component.
        decode_client = (
            await runtime.namespace("dynamo")
            .component("decode")
            .endpoint("generate")
            .client()
        )
        handler = RequestHandler(engine, server_args, component, decode_client)
    else:
        handler = RequestHandler(engine, server_args, component)

    # Set up the engine metrics receiver
    handler.setup_metrics()

    # Set up ZMQ kv event publisher
    if server_args.kv_events_config:
        kv_events = json.loads(server_args.kv_events_config)
        ep = kv_events.get("endpoint")
        # A "*" in the configured endpoint is replaced by this host's IP.
        zmq_ep = ep.replace("*", get_ip()) if ep else None

        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=endpoint.lease_id(),
            kv_block_size=server_args.page_size,
            zmq_endpoint=zmq_ep,
        )
        logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
        # NOTE(review): the publisher is bound to a throwaway name;
        # presumably constructing it is enough to start publishing - confirm.
        _ = ZmqKvEventPublisher(component=component, config=zmq_config)

    # Serve the generate endpoint together with the native endpoints.
    tasks = [endpoint.serve_endpoint(handler.generate)]
    tasks.extend(setup_native_endpoints(server_args, component, handler))
    await asyncio.gather(*tasks)
async def register_llm_with_runtime_config(
    engine: sgl.Engine,
    endpoint: Endpoint,
    server_args: ServerArgs,
    migration_limit: int,
):
    """Register the model on ``endpoint``, attaching the engine's runtime config.

    Registration failures are logged and not re-raised; the function
    returns None in every case.
    """
    runtime_config = await _get_runtime_config(engine)

    try:
        registration_options = {
            "kv_cache_block_size": server_args.page_size,
            "migration_limit": migration_limit,
            "runtime_config": runtime_config,
        }
        await register_llm(
            ModelType.Backend,
            endpoint,
            server_args.model_path,
            server_args.served_model_name,
            **registration_options,
        )
    except Exception as e:
        logging.error(f"Failed to register with runtime config: {e}")
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
    """Build a runtime config from the engine's ``scheduler_info``.

    Returns a ``ModelRuntimeConfig`` when ``engine.scheduler_info`` is
    available; ``total_kv_blocks`` is filled in when both
    ``max_total_num_tokens`` and the page size are known. Returns None when
    ``scheduler_info`` is missing or any error occurs.
    """
    try:
        scheduler_info = getattr(engine, "scheduler_info", None)
        if scheduler_info is None:
            # Nothing to read from: log and let the caller go on without it.
            logging.warning(
                "Could not access runtime config from SGLang engine. "
                "The engine may compute these values internally after initialization. "
                "Proceeding without runtime config - SGLang will use its internal defaults."
            )
            return None

        runtime_config = ModelRuntimeConfig()

        # Get max_total_num_tokens from scheduler_info
        if "max_total_num_tokens" in scheduler_info:
            max_total_tokens = scheduler_info["max_total_num_tokens"]
            if max_total_tokens and hasattr(engine.tokenizer_manager, "server_args"):
                page_size = engine.tokenizer_manager.server_args.page_size
                if page_size:
                    # Ceiling division: a partially filled page still
                    # counts as one block.
                    block_count = (max_total_tokens + page_size - 1) // page_size
                    runtime_config.total_kv_blocks = block_count
                    logging.info(
                        f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
                        f"(max_total_tokens={max_total_tokens}, page_size={page_size})"
                    )

        # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
        # TODO: figure out where they are
        return runtime_config

    except Exception as e:
        logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.")
        return None
def main():
    """Install uvloop and run the worker coroutine to completion."""
    uvloop.install()
    worker_entrypoint = worker()
    asyncio.run(worker_entrypoint)


if __name__ == "__main__":
    main()
...@@ -27,7 +27,7 @@ ARG ARCH=amd64 ...@@ -27,7 +27,7 @@ ARG ARCH=amd64
ARG ARCH_ALT=x86_64 ARG ARCH_ALT=x86_64
# Make sure to update the dependency version in pyproject.toml when updating this # Make sure to update the dependency version in pyproject.toml when updating this
ARG SGLANG_VERSION="0.4.9.post6" ARG SGLANG_VERSION="0.5.0rc2"
################################## ##################################
########## Base Image ############ ########## Base Image ############
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment