Unverified Commit b43c72a5 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat(sglang): disaggregated support (#976)


Co-authored-by: default avatarishandhanani <ishandhananai@gmail.com>
parent c42b1a9a
......@@ -20,7 +20,7 @@ limitations under the License.
This directory contains examples and reference implementations for deploying Large Language Models (LLMs) in various configurations using SGLang. SGLang internally uses ZMQ to communicate between the ingress and the engine processes. For Dynamo, we leverage the runtime to communicate directly with the engine processes and handle ingress and pre/post processing on our end.
> [!IMPORTANT]
> In order to run these examples, you will need to install sglang using `uv pip install "sglang[all]>=0.4.6.post2"`. Additionally, SGLang currently does not have pre-built wheels for ARM. If you are on an ARM machine - you will need to install SGLang from source.
> In order to run these examples, you will need to install sglang using `uv pip install "sglang[all]>=0.4.6.post4"`. Additionally, SGLang currently does not have pre-built wheels for ARM. If you are on an ARM machine - you will need to install SGLang from source.
## Deployment Architectures
......@@ -61,3 +61,22 @@ docker compose -f deploy/docker-compose.yml up -d
cd /workspace/examples/sglang
dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
```
#### Disaggregated
As of `sglang==0.4.6.post4`, SGLang uses a mini load balancer to route requests for disaggregated serving. The load balancer works as follows:
1. The load balancer receives a request from the client
2. A random `(prefill, decode)` pair is selected from the pool of available workers
3. Request is sent to both `prefill` and `decode` workers via asyncio tasks
4. Internally disaggregation is done from prefill -> decode
Because Dynamo has a discovery mechanism, we do not use a load balancer. Instead, we first route to a random prefill worker, select a random decode worker, and then send the request to both. Internally, SGLang's bootstrap server (which is a part of the `tokenizer_manager`) is used in conjunction with NIXL to handle the KV transfer.
> [!IMPORTANT]
> Disaggregated serving in SGLang currently requires each worker to have the same tensor parallel size [unless you are using an MLA based model](https://github.com/sgl-project/sglang/pull/5922)
```bash
cd /workspace/examples/sglang
dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import sglang as sgl
from utils.protocol import DisaggPreprocessedRequest
from utils.sglang import parse_sglang_args
from dynamo.sdk import dynamo_endpoint, service
logger = logging.getLogger(__name__)
@service(
    dynamo={"enabled": True, "namespace": "dynamo"},
    resources={"gpu": 1},
    workers=1,
)
class SGLangDecodeWorker:
    """Dynamo service wrapping an SGLang engine for the decode side of a
    disaggregated deployment.

    The engine is built from the arguments that ``parse_sglang_args`` returns
    for this service's class name.
    """

    def __init__(self):
        # The class name is the key used to look up this service's arguments.
        service_name = type(self).__name__
        self.engine_args = parse_sglang_args(service_name, "")
        self.engine = sgl.Engine(server_args=self.engine_args)
        logger.warning("Decode worker initialized")

    @dynamo_endpoint()
    async def generate(self, req: DisaggPreprocessedRequest):
        """Stream engine results for ``req``.

        The token ids and sampling parameters come from the request, and the
        bootstrap host/port/room it carries are forwarded to the engine
        unchanged. Every streamed result is yielded as-is.
        """
        stream = await self.engine.async_generate(
            input_ids=req.request.token_ids,
            sampling_params=req.sampling_params,
            stream=True,
            bootstrap_host=req.bootstrap_host,
            bootstrap_port=req.bootstrap_port,
            bootstrap_room=req.bootstrap_room,
        )
        async for chunk in stream:
            yield chunk
......@@ -24,15 +24,19 @@ For now - the SGLangWorker will be responsible for aggregated and prefill and we
have a separate DecodeWorker.
"""
import asyncio
import logging
import signal
import random
import socket
import sglang as sgl
from utils.protocol import PreprocessedRequest
from components.decode_worker import SGLangDecodeWorker
from sglang.srt.utils import get_ip
from utils.protocol import DisaggPreprocessedRequest, PreprocessedRequest
from utils.sglang import parse_sglang_args
from dynamo.llm import ModelType, register_llm
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
logger = logging.getLogger(__name__)
......@@ -45,14 +49,13 @@ logger = logging.getLogger(__name__)
workers=1,
)
class SGLangWorker:
decode_worker = depends(SGLangDecodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_sglang_args(class_name, "")
self.engine = sgl.Engine(server_args=self.engine_args)
for sig in [signal.SIGINT, signal.SIGTERM]:
signal.signal(sig, self.shutdown_sglang_engine)
logger.info("SGLangWorker initialized")
@async_on_start
......@@ -67,10 +70,33 @@ class SGLangWorker:
self.engine_args.model_path,
self.engine_args.served_model_name,
)
def shutdown_sglang_engine(self, signum, frame):
self.engine.shutdown()
logger.info("SGLang engine shutdown")
if self.engine_args.disaggregation_mode:
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info()
comp_ns, comp_name = SGLangDecodeWorker.dynamo_address() # type: ignore
self.decode_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
def _get_bootstrap_info(self):
    """Return the ``(bootstrap_host, bootstrap_port)`` pair for this worker.

    Both values are derived from the server args held by the engine's
    tokenizer manager. The port is its ``disaggregation_bootstrap_port``.
    The host is the resolved address of the host part of ``dist_init_addr``
    when that is set, and the result of ``get_ip()`` otherwise.
    """
    server_args = self.engine.tokenizer_manager.server_args
    port = server_args.disaggregation_bootstrap_port

    dist_init_addr = server_args.dist_init_addr
    if dist_init_addr:
        # Multinode: dist_init_addr is "host:port"; only the host part is resolved.
        head_host = dist_init_addr.partition(":")[0]
        host = socket.gethostbyname(head_host)
    else:
        host = get_ip()

    return host, port
def _build_sampling_params(self, request: PreprocessedRequest) -> dict:
# TODO: maintain a full mapping from PreprocessedRequest to SGLang's SamplingParams
......@@ -90,19 +116,63 @@ class SGLangWorker:
async def generate(self, request: PreprocessedRequest):
    """Stream generated tokens for ``request``.

    When ``engine_args.disaggregation_mode`` is ``"null"`` the local engine
    serves the request on its own. Otherwise a bootstrap room id is generated
    and the same request is issued twice with the same bootstrap
    host/port/room:

    * to the local engine, whose output is drained in a background task and
      discarded, and
    * to the decode worker through ``self.decode_client``, whose stream is
      what gets yielded to the caller.

    Yields:
        The dicts produced by ``_process_stream``.
    """
    # TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
    sampling_params = self._build_sampling_params(request)

    if self.engine_args.disaggregation_mode != "null":
        bootstrap_room = self._generate_bootstrap_room()

        # decode worker request
        disagg_request = DisaggPreprocessedRequest(
            request=request,
            sampling_params=sampling_params,
            bootstrap_host=self.bootstrap_host,
            bootstrap_port=self.bootstrap_port,
            bootstrap_room=bootstrap_room,
        )

        # prefill response is not used
        prefill = await self.engine.async_generate(
            input_ids=request.token_ids,
            sampling_params=sampling_params,
            stream=True,
            bootstrap_host=self.bootstrap_host,
            bootstrap_port=self.bootstrap_port,
            bootstrap_room=bootstrap_room,
        )
        prefill_task = asyncio.create_task(self._prefill_generator(prefill))

        try:
            decode = await self.decode_client.generate(
                disagg_request.model_dump_json()
            )
            async for out in self._process_stream(decode, unpack=True):
                yield out
            await prefill_task
        finally:
            # If the decode stream failed or the caller stopped consuming
            # early, do not leave the background drain task orphaned.
            if not prefill_task.done():
                prefill_task.cancel()
    else:
        g = await self.engine.async_generate(
            input_ids=request.token_ids,
            sampling_params=sampling_params,
            stream=True,
        )
        async for out in self._process_stream(g, unpack=False):
            yield out
async def _process_stream(self, stream_source, unpack: bool):
    """Convert a stream of cumulative engine results into incremental outputs.

    Args:
        stream_source: Async iterable of results. Each result (after optional
            unpacking) is indexed as ``["meta_info"]["finish_reason"]`` and
            ``["output_ids"]``.
        unpack: When true, each item is first unwrapped with ``item.data()``.

    Yields:
        ``{"token_ids": [...]}`` holding only the ids that were not emitted
        yet, or ``{"token_ids": [], "finish_reason": <type>}`` for a result
        that carries a finish reason.
    """
    num_output_tokens_so_far = 0
    async for res in stream_source:
        data = res.data() if unpack else res
        finish_reason = data["meta_info"]["finish_reason"]

        if finish_reason:
            # Don't forward the stop token
            yield {"token_ids": [], "finish_reason": finish_reason["type"]}
            # Nothing was emitted, so the running count stays untouched. This
            # also covers a stream whose very first result is already finished.
            continue

        output_ids = data["output_ids"]
        # output_ids is cumulative; forward only the new tail.
        yield {"token_ids": output_ids[num_output_tokens_so_far:]}
        num_output_tokens_so_far = len(output_ids)
def _generate_bootstrap_room(self):
    """Return a random integer in ``[0, 2**63 - 1]`` to use as a bootstrap room id."""
    highest_room = (1 << 63) - 1  # largest signed 64-bit value
    return random.randint(0, highest_room)
async def _prefill_generator(self, prefill):
    """Consume the ``prefill`` async stream to exhaustion, discarding every item."""
    async for _unused_chunk in prefill:
        continue
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.SGLangWorker.generate
port: 8000
# We set disaggregation-bootstrap-port in utils/sglang.py to ensure unique ports for each replica
SGLangWorker:
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
served-model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
tp: 1
trust-remote-code: true
skip-tokenizer-init: true
disaggregation-mode: prefill
disaggregation-transfer-backend: nixl
ServiceArgs:
workers: 1
resources:
gpu: 1
SGLangDecodeWorker:
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
served-model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
tp: 1
trust-remote-code: true
skip-tokenizer-init: true
disaggregation-mode: decode
disaggregation-transfer-backend: nixl
ServiceArgs:
workers: 1
resources:
gpu: 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.decode_worker import SGLangDecodeWorker
from components.frontend import Frontend
from components.worker import SGLangWorker
# Disaggregated graph: Frontend -> SGLangWorker -> SGLangDecodeWorker.
# NOTE(review): presumably `.link` declares the dependency edges that
# `dynamo serve graphs.disagg:Frontend` deploys - confirm against the SDK.
Frontend.link(SGLangWorker).link(SGLangDecodeWorker)
......@@ -52,3 +52,11 @@ class PreprocessedRequest(BaseModel):
eos_token_ids: List[TokenIdType] = Field(default_factory=list)
mdc_sum: Optional[str] = None
annotations: List[str] = Field(default_factory=list)
class DisaggPreprocessedRequest(BaseModel):
    """Payload sent to the decode worker's ``generate`` endpoint in
    disaggregated mode.

    It bundles the preprocessed request with the sampling parameters already
    built for it and the bootstrap host/port/room that the sending worker
    also passes to its own engine call.
    """

    # The preprocessed request; the decode worker reads its token_ids.
    request: PreprocessedRequest
    # Sampling parameters, passed to the engine as-is.
    sampling_params: dict
    # Bootstrap host and port obtained by the sending worker.
    bootstrap_host: str
    bootstrap_port: int
    # Per-request room id; the same value is used on both workers.
    bootstrap_room: int
......@@ -14,9 +14,11 @@
# limitations under the License.
import argparse
from argparse import Namespace
from sglang.srt.server_args import ServerArgs
from dynamo.sdk.cli.utils import reserve_free_port
from dynamo.sdk.lib.config import ServiceConfig
......@@ -24,9 +26,23 @@ def parse_sglang_args(service_name, prefix) -> ServerArgs:
config = ServiceConfig.get_instance()
sglang_args = config.as_args(service_name, prefix=prefix)
parser = argparse.ArgumentParser()
bootstrap_port = _reserve_disaggregation_bootstrap_port()
# add future dynamo arguments here
ServerArgs.add_cli_args(parser)
args = parser.parse_args(sglang_args)
args_dict = vars(args)
args_dict["disaggregation_bootstrap_port"] = bootstrap_port
args = Namespace(**args_dict)
return ServerArgs.from_cli_args(args)
def _reserve_disaggregation_bootstrap_port():
    """Pick a free port to use as ``disaggregation_bootstrap_port``.

    Every worker needs its own bootstrap port, so the port is obtained from
    the shared ``reserve_free_port`` helper instead of being hard-coded.

    NOTE(review): the context manager has already exited by the time the port
    is returned; presumably the reservation is dropped at that point, so
    another process could take the port before the engine binds it - confirm.
    """
    with reserve_free_port() as reserved_port:
        bootstrap_port = reserved_port
    return bootstrap_port
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment