Unverified Commit 3bf22bb4 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: reorganize sglang and add expert distribution endpoints (#2181)

parent f10aab3b
......@@ -139,6 +139,8 @@ cd $DYNAMO_ROOT/components/backends/sglang
./launch/disagg_dp_attn.sh
```
When using MoE models, you can also use the our implementation of the native SGLang endpoints to record expert distribution data. The `disagg_dp_attn.sh` script automatically sets up the SGLang HTTP server, the environment variable that controls the expert distribution recording directory, and sets up the expert distribution recording mode to `stat`. You can learn more about expert parallelism load balancing [here](docs/expert-distribution-eplb.md).
## Request Migration
In a [Distributed System](#distributed-system), a request may fail due to connectivity issues between the Frontend and the Backend.
......@@ -166,12 +168,6 @@ Below we provide a selected list of advanced examples. Please open up an issue i
- **[Run DeepSeek-R1 on 104+ H100s](docs/dsr1-wideep-h100.md)**
- **[Run DeepSeek-R1 on GB200s](docs/dsr1-wideep-gb200.md)**
### Speculative Decoding
- **[Deploying DeepSeek-R1 with MTP - coming soon!](.)**
### Structured Output and Tool Calling
- **[Tool calling with Dynamo - coming soon!](.)**
### Supporting SGLang's native endpoints via Dynamo
- **[HTTP Server for native SGLang endpoints](docs/sgl-http-server.md)**
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->
# Expert Parallelism Load Balancer (EPLB) in SGLang
Mixture-of-Experts (MoE) models utilize a technique called Expert Parallelism (EP), where experts are distributed across multiple GPUs. While this allows for much larger and more powerful models, it can lead to an uneven workload distribution. Because the load on different experts may vary depending on the workload, some GPUs can become bottlenecks, forcing the entire system to wait. This imbalance leads to wasted compute cycles and increased memory usage.
To address this, SGLang implements an Expert Parallelism Load Balancer (EPLB) inspired by the work in the DeepSeek-V3 paper. EPLB analyzes expert usage patterns and dynamically re-arranges the experts across the available GPUs to ensure a more balanced workload.
## The EPLB Algorithm: Core Concepts
The load balancing algorithm revolves around a few key ideas to achieve an optimal distribution of work.
### Redundant Experts for Flexibility
The core strategy is to create **redundant experts**. Instead of being limited to the model's original number of experts, EPLB can create duplicates of heavily-loaded experts. For example, if a model has 256 experts, you can configure EPLB to create an additional 32 "redundant" experts, bringing the total to 288. This pool of replicated experts is then strategically packed onto the available GPUs. A popular expert might be duplicated multiple times, while a moderately used expert might be grouped with several rarely used ones on a single GPU.
### Group-Limited Routing for Efficiency
Modern MoE models like DeepSeek-V3 use **group-limited expert routing**. In this design, experts are organized into groups, and routing decisions are constrained within these groups. EPLB can take advantage of this structure to reduce inter-node data traffic by attempting to place all experts from the same group onto the same node whenever possible.
### Load Balancing Policies
The algorithm comes with two policies for different scenarios:
1. **Hierarchical Load Balancing**: This policy is used when the number of server nodes evenly divides the number of expert groups. It first harnesses the group-limited routing by packing expert groups onto nodes to balance the load between nodes. Then, within each node, it replicates and packs the experts onto individual GPUs to balance the load locally. This is often used during prefill where the expert-parallel size might be smaller.
2. **Global Load Balancing**: In all other cases, a global policy is used. It replicates experts globally without regard to their group affiliation and packs them onto individual GPUs. This policy is more general and can be adopted during the decoding stage with a larger expert-parallel size.
## How SGLang Implements EPLB
SGLang provides a robust implementation of EPLB, allowing for dynamic, online rebalancing of expert locations based on real-world traffic.
### Dynamic Rebalancing
You can enable dynamic rebalancing by setting the `--enable-eplb` flag. When enabled, the `EPLBManager` runs in the background. It periodically triggers a rebalance after a certain number of requests, configured with `--eplb-rebalance-num-iterations`. At each rebalance, it computes a new expert placement plan based on the latest usage statistics and updates the model's expert locations on the fly.
### Expert Usage Recording
To make intelligent balancing decisions, SGLang needs to collect data on expert usage. The `ExpertDistributionRecorder` is responsible for this, and its behavior is controlled by the `--expert-distribution-recorder-mode` flag. This flag determines the granularity of the collected data. When `enable_eplb` is on, this mode defaults to `stat` to gather statistics for rebalancing. The available modes are:
- **`per_token`**: This is the most detailed mode. It records the specific expert choices for every single token processed by the model. While it provides the richest data, it also has the highest performance overhead. The raw, unaggregated data for each forward pass is stored.
- **`per_pass`**: In this mode, SGLang records the aggregated expert usage counts for each individual forward pass. The data is not aggregated across different passes, giving you a snapshot of expert popularity for each batch of requests.
- **`stat`**: This mode also records the exact expert usage counts for each forward pass, but it then aggregates these counts across multiple passes (the number of passes is determined by `--expert-distribution-recorder-buffer-size`). This provides a moving average of expert usage statistics and is the default when EPLB is enabled.
- **`stat_approx`**: This mode is similar to `stat` but gathers _approximate_ statistics, usually from the DeepEP dispatcher. This method has lower overhead than `stat` but is less precise, especially for small batch sizes. It is a good choice when performance is critical.
The collected statistics are then fed into the rebalancing algorithm to generate a new expert placement plan.
### Initializing with a Pre-computed Distribution
While SGLang can start with a simple default layout and learn a better one over time, you can also provide it with a pre-computed expert distribution to start with. The `--init-expert-location` flag allows you to specify a file path (`.pt` or `.json`) or a JSON string containing an expert layout. This is useful if you have already analyzed a representative workload offline and want the server to start immediately with a balanced configuration. If this flag is not set, it defaults to a `trivial` sequential layout.
### References and further reading
- [SGLang Large Scale P/D + WideEP Deployment](https://lmsys.org/blog/2025-05-05-large-scale-ep/#expert-parallelism-load-balancer)
- [Deepseek's EPLB repository](https://github.com/deepseek-ai/EPLB)
......@@ -9,6 +9,9 @@ SPDX-License-Identifier: Apache-2.0
The SGLang HTTP server provides a REST API interface for managing and monitoring SGLang components running in a dynamo distributed environment. It leverages dynamo's service discovery mechanism to automatically find and communicate with SGLang workers across the cluster.
<details>
<summary>How it works under the hood</summary>
## Architecture Overview
The HTTP server (`sgl_http_server.py`) is built on FastAPI and integrates with dynamo's `DistributedRuntime` to discover and interact with SGLang components. It uses the following discovery flow:
......@@ -28,38 +31,27 @@ The server uses dynamo's hierarchical service discovery structure:
The discovery process queries etcd with the prefix `instances/` to find all registered components that expose the target endpoint. Components are identified by their namespace, component name, and endpoint, allowing the server to dynamically scale operations across multiple instances.
## Supported Endpoints
### Current Endpoints
</details>
#### POST /flush_cache
Flushes the radix cache across all discovered SGLang components.
## Supported Endpoints
**Behavior:**
- Discovers all components in the specified namespace that expose the `flush_cache` endpoint
- Sends flush requests to all instances of each discovered component
- Returns success/failure status with details about the operation
All of these endpoints can be called using
**Response:**
```json
{
"message": "Cache flush initiated",
"success": true
}
```bash
curl -X POST http://<ip>:9001/<endpoint>
```
### Upcoming Endpoints
The following endpoints will be supported in future releases:
#### `/flush_cache`
Flushes the kv cache across all SGLang components. Useful for resetting after a warmup or a benchmarking run.
#### POST /start_expert_distribution_record
#### `/start_expert_distribution_record`
Begins recording expert distribution metrics across SGLang components.
#### POST /stop_expert_distribution_record
#### `/stop_expert_distribution_record`
Stops the expert distribution recording process.
#### GET /dump_expert_distribution_record
Retrieves the collected expert distribution data.
#### `/dump_expert_distribution_record`
Dumps the collected expert distribution data.
## Configuration
......@@ -67,8 +59,6 @@ The server accepts the following command-line arguments:
- `--port`: HTTP server port (default: 9001)
- `--ns/--namespace`: Target dynamo namespace (default: "dynamo")
- `--comp/--component`: Specific component name to target (default: discover all)
- `--endpoint`: Endpoint name to discover (default: "flush_cache")
## Usage
......
......@@ -15,7 +15,7 @@ trap cleanup EXIT INT TERM
python3 -m dynamo.sglang.utils.clear_namespace --namespace dynamo
# run ingress
dynamo run in=http out=dyn --router-mode kv --http-port=8000 &
python -m dynamo.frontend --router-mode kv --http-port=8000 &
DYNAMO_PID=$!
# run worker
......
......@@ -5,8 +5,8 @@
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID 2>/dev/null || true
kill $DYNAMO_PID $PREFILL_PID $HTTP_SERVER_PID 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID $HTTP_SERVER_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
......@@ -18,6 +18,14 @@ python3 -m dynamo.sglang.utils.clear_namespace --namespace dynamo
python3 -m dynamo.frontend --http-port=8000 &
DYNAMO_PID=$!
# run http server
python3 -m dynamo.sglang.utils.sgl_http_server --namespace dynamo &
HTTP_SERVER_PID=$!
# Set the expert distribution recording directory
mkdir -p /tmp/sglang_expert_distribution_record
export SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR=/tmp/sglang_expert_distribution_record
# run prefill worker
python3 -m dynamo.sglang.worker \
--model-path silence09/DeepSeek-R1-Small-2layers \
......@@ -29,6 +37,7 @@ python3 -m dynamo.sglang.worker \
--skip-tokenizer-init \
--disaggregation-mode prefill \
--disaggregation-transfer-backend nixl \
--expert-distribution-recorder-mode stat \
--port 30000 &
PREFILL_PID=$!
......@@ -43,4 +52,5 @@ CUDA_VISIBLE_DEVICES=2,3 python3 -m dynamo.sglang.decode_worker \
--skip-tokenizer-init \
--disaggregation-mode decode \
--disaggregation-transfer-backend nixl \
--expert-distribution-recorder-mode stat \
--port 31000
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Base handlers
from .base_handlers import BaseWorkerHandler
# Protocol types
from .protocol import (
DisaggPreprocessedRequest,
PreprocessedRequest,
SamplingOptions,
StopConditions,
TokenIdType,
)
# Utilities
from .sgl_utils import (
graceful_shutdown,
parse_sglang_args_inc,
reserve_free_port,
setup_native_endpoints,
)
__all__ = [
# Protocol types
"DisaggPreprocessedRequest",
"PreprocessedRequest",
"SamplingOptions",
"StopConditions",
"TokenIdType",
# Utilities
"parse_sglang_args_inc",
"reserve_free_port",
"graceful_shutdown",
"setup_native_endpoints",
# Base handlers
"BaseWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Optional
import sglang as sgl
from sglang.srt.server_args import ServerArgs
class BaseWorkerHandler(ABC):
"""
Abstract base class for sglang request handlers. We use this to implement native sglang endpoints for
workers
"""
@abstractmethod
def __init__(
self,
engine: sgl.Engine,
server_args: ServerArgs,
component,
decode_client: Optional[Any] = None,
):
self.engine = engine
self.server_args = server_args
self.component = component
@abstractmethod
async def generate(self, request):
"""Generate tokens from the engine"""
...
async def flush_cache(self, request: dict):
"""Flush KV cache for each worker"""
_ = request
await self.engine.tokenizer_manager.flush_cache()
yield True
async def start_expert_distribution_record(self, request: dict):
"""
Start recording expert distribution.
"""
_ = request
await self.engine.tokenizer_manager.start_expert_distribution_record()
yield True
async def stop_expert_distribution_record(self, request: dict):
"""
Stop recording expert distribution.
"""
_ = request
await self.engine.tokenizer_manager.stop_expert_distribution_record()
yield True
async def dump_expert_distribution_record(self, request: dict):
"""
Dumps the expert distribution record to the directory specified in the environment variable `SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR`.
"""
_ = request
await self.engine.tokenizer_manager.dump_expert_distribution_record()
yield True
......@@ -15,6 +15,7 @@
import argparse
import contextlib
import logging
import socket
from argparse import Namespace
......@@ -55,3 +56,47 @@ def _reserve_disaggregation_bootstrap_port():
"""
with reserve_free_port() as port:
return port
async def graceful_shutdown(runtime):
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
def setup_native_endpoints(server_args, component, handler):
"""Setup sgl native endpoints"""
# flush cache
flush_endpoint = component.endpoint("flush_cache")
tasks = []
tasks.append(flush_endpoint.serve_endpoint(handler.flush_cache))
# expert distribution endpoints
if server_args.expert_distribution_recorder_mode is not None:
start_expert_distribution_endpoint = component.endpoint(
"start_expert_distribution_record"
)
stop_expert_distribution_endpoint = component.endpoint(
"stop_expert_distribution_record"
)
dump_expert_distribution_endpoint = component.endpoint(
"dump_expert_distribution_record"
)
tasks.append(
start_expert_distribution_endpoint.serve_endpoint(
handler.start_expert_distribution_record
)
)
tasks.append(
stop_expert_distribution_endpoint.serve_endpoint(
handler.stop_expert_distribution_record
)
)
tasks.append(
dump_expert_distribution_endpoint.serve_endpoint(
handler.dump_expert_distribution_record
)
)
return tasks
......@@ -15,14 +15,19 @@ from sglang.srt.server_args import ServerArgs
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.utils.sgl_utils import parse_sglang_args_inc
from dynamo.sglang.common import (
BaseWorkerHandler,
graceful_shutdown,
parse_sglang_args_inc,
setup_native_endpoints,
)
configure_dynamo_logging()
class DecodeRequestHandler:
def __init__(self, engine: sgl.Engine):
self.engine = engine
class DecodeRequestHandler(BaseWorkerHandler):
def __init__(self, engine: sgl.Engine, server_args: ServerArgs, component):
super().__init__(engine, server_args, component)
logging.info("Decode request handler initialized")
async def generate(self, request: str):
......@@ -42,20 +47,6 @@ class DecodeRequestHandler:
async for result in results:
yield result
async def flush_cache(self, request: dict):
_ = request
asyncio.create_task(self.engine.tokenizer_manager.flush_cache())
yield {
"status": "success",
"message": "Cache flush initiated. Check backend logs for status",
}
async def graceful_shutdown(runtime):
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
......@@ -80,16 +71,16 @@ async def init(runtime: DistributedRuntime, server_args: ServerArgs):
engine = sgl.Engine(server_args=server_args)
handler = DecodeRequestHandler(engine)
component = runtime.namespace("dynamo").component("decode")
await component.create_service()
handler = DecodeRequestHandler(engine, server_args, component)
gen_endpoint = component.endpoint("generate")
flush_endpoint = component.endpoint("flush_cache")
tasks = [gen_endpoint.serve_endpoint(handler.generate)]
tasks.append(flush_endpoint.serve_endpoint(handler.flush_cache))
tasks.extend(setup_native_endpoints(server_args, component, handler))
await asyncio.gather(*tasks)
......
......@@ -8,6 +8,7 @@ import logging
import uvicorn
import uvloop
from fastapi import FastAPI
from fastapi.routing import APIRoute
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -25,7 +26,7 @@ class SglangHttpServer:
self.args = args
self.setup_routes()
async def _discover_endpoints(self):
async def _discover_endpoints(self, endpoint_name):
"""Discover endpoints that match the pattern"""
etcd_client = self.runtime.etcd_client()
if etcd_client is None:
......@@ -34,7 +35,7 @@ class SglangHttpServer:
prefix = "instances/"
kvs = await etcd_client.kv_get_prefix(prefix)
# Collect (namespace, component) combos that expose flush_cache
# Collect (namespace, component) combos that expose the target endpoint
discovered = set()
for kv in kvs:
key = kv["key"] if isinstance(kv, dict) else kv.key
......@@ -55,7 +56,7 @@ class SglangHttpServer:
continue
ep_name = ep_with_lease.split(":", 1)[0]
if ep_name == self.args.endpoint:
if ep_name == endpoint_name:
discovered.add((ns, comp))
logging.debug(f"Discovered endpoint: {ns}.{comp}")
......@@ -64,45 +65,100 @@ class SglangHttpServer:
)
return discovered
async def _dispatch_command(
self, endpoint_name: str, payload: dict | str = "{}", success_message: str = ""
):
"""Dispatches a command to all instances of a discovered endpoint."""
discovered = await self._discover_endpoints(endpoint_name=endpoint_name)
if not discovered:
return {"message": "No matching endpoints found", "success": False}
logging.debug(
f"Found components: {', '.join([f'{ns}.{comp}' for ns, comp in discovered])}"
)
for ns, comp in discovered:
ep = self.runtime.namespace(ns).component(comp).endpoint(endpoint_name)
client = await ep.client()
await client.wait_for_instances()
ids = client.instance_ids()
logging.debug(f"-- {ns}.{comp} : {len(ids)} instances --")
for inst_id in ids:
try:
stream = await client.direct(payload, inst_id)
async for stream_payload in stream:
logging.debug(f"[{ns}.{comp}][{inst_id}] -> {stream_payload}")
except Exception as e:
logging.error(
f"[{ns}.{comp}][{inst_id}] {endpoint_name} error: {e}"
)
return {"message": success_message, "success": True}
def setup_routes(self):
@self.app.post("/flush_cache")
async def flush_cache():
"""Flush the radix cache."""
endpoint_name = self.args.endpoint
try:
discovered = await self._discover_endpoints()
if not discovered:
return {"message": "No matching endpoints found", "success": False}
logging.debug(
f"Found components: {', '.join([f'{ns}.{comp}' for ns, comp in discovered])}"
return await self._dispatch_command(
endpoint_name,
success_message="Cache flush initiated",
)
for ns, comp in discovered:
ep = (
self.runtime.namespace(ns)
.component(comp)
.endpoint(self.args.endpoint)
)
client = await ep.client()
await client.wait_for_instances()
ids = client.instance_ids()
logging.debug(f"-- {ns}.{comp} : {len(ids)} instances --")
for inst_id in ids:
try:
stream = await client.direct("{}", inst_id)
async for payload in stream:
logging.debug(f"[{ns}.{comp}][{inst_id}] -> {payload}")
except Exception as e:
logging.error(f"[{ns}.{comp}][{inst_id}] flush error: {e}")
return {"message": "Cache flush initiated", "success": True}
except Exception as e:
logging.error(f"Cache flush error: {e}")
return {"message": f"Cache flush failed: {str(e)}", "success": False}
@self.app.post("/start_expert_distribution_record")
async def start_expert_distribution_record():
"""Start recording expert distribution."""
endpoint_name = "start_expert_distribution_record"
try:
return await self._dispatch_command(
endpoint_name,
success_message="Expert distribution recording started",
)
except Exception as e:
logging.error(f"Start expert distribution error: {e}")
return {
"message": f"Start expert distribution failed: {str(e)}",
"success": False,
}
@self.app.post("/stop_expert_distribution_record")
async def stop_expert_distribution_record():
"""Stop recording expert distribution."""
endpoint_name = "stop_expert_distribution_record"
try:
return await self._dispatch_command(
endpoint_name,
success_message="Expert distribution recording stopped",
)
except Exception as e:
logging.error(f"Stop expert distribution error: {e}")
return {
"message": f"Stop expert distribution failed: {str(e)}",
"success": False,
}
@self.app.post("/dump_expert_distribution_record")
async def dump_expert_distribution_record(request: dict):
"""Dump expert distribution recording to specified directory."""
endpoint_name = "dump_expert_distribution_record"
try:
return await self._dispatch_command(
endpoint_name,
success_message="Expert distribution recording dumped to directory",
)
except Exception as e:
logging.error(f"Dump expert distribution error: {e}")
return {
"message": f"Dump expert distribution failed: {str(e)}",
"success": False,
}
async def start_server(self):
"""Start the HTTP server"""
config = uvicorn.Config(
......@@ -112,10 +168,10 @@ class SglangHttpServer:
)
server = uvicorn.Server(config)
# Single nice log with available endpoints
logging.info(
f"🚀 SGL engine HTTP server running on http://0.0.0.0:{self.port} - Endpoints: POST /flush_cache"
)
# Debug: print all registered routes
for route in self.app.routes:
if isinstance(route, APIRoute):
logging.debug(f"Registered route: {route.methods} {route.path}")
await server.serve()
......@@ -135,9 +191,6 @@ def parse_args():
default=None,
help="Specify component name (default: discover all)",
)
p.add_argument(
"--endpoint", default=FLUSH_CACHE_ENDPOINT, help="Specify endpoint name"
)
return p.parse_args()
......
......@@ -27,13 +27,18 @@ from dynamo.llm import (
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.utils.protocol import DisaggPreprocessedRequest
from dynamo.sglang.utils.sgl_utils import parse_sglang_args_inc
from dynamo.sglang.common import (
BaseWorkerHandler,
DisaggPreprocessedRequest,
graceful_shutdown,
parse_sglang_args_inc,
setup_native_endpoints,
)
configure_dynamo_logging()
class RequestHandler:
class RequestHandler(BaseWorkerHandler):
def __init__(
self,
engine: sgl.Engine,
......@@ -41,9 +46,7 @@ class RequestHandler:
component,
decode_client: Optional[Any] = None,
):
self.engine = engine
self.server_args = server_args
self.component = component
super().__init__(engine, server_args, component, decode_client)
self.metrics_publisher = WorkerMetricsPublisher()
self.zmq_context = zmq.asyncio.Context() # type: ignore
......@@ -291,12 +294,6 @@ class RequestHandler:
}
async def graceful_shutdown(runtime):
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Set up signal handler for graceful shutdown
......@@ -368,8 +365,7 @@ async def init(
tasks = [endpoint.serve_endpoint(handler.generate)]
flush_endpoint = component.endpoint("flush_cache")
tasks.append(flush_endpoint.serve_endpoint(handler.flush_cache))
tasks.extend(setup_native_endpoints(server_args, component, handler))
await asyncio.gather(*tasks)
......
......@@ -146,6 +146,7 @@ addopts = [
"--ignore-glob=*_inc.py",
"--ignore-glob=*/llm/tensorrtllm*",
"--ignore-glob=docs/*",
"--ignore-glob=components/backends/sglang/src/dynamo/sglang/common/*",
# FIXME: Get relative/generic blob paths to work here
]
xfail_strict = true
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment