Unverified Commit 75135580 authored by Byron Hsu, committed by GitHub

[PD] Add doc and simplify sender.send (#6019)

parent 4d643f6c
# PD Disaggregation
## Why and What is PD Disaggregation?
Large Language Model (LLM) inference comprises two distinct phases: **Prefill** and **Decode**. The Prefill phase is computation-intensive, processing the entire input sequence, while the Decode phase is memory-intensive, managing the Key-Value (KV) cache for token generation. Traditionally, these phases are handled within a unified engine, where combined scheduling of prefill and decode batches introduces inefficiencies. To address these challenges, we introduce **Prefill and Decode (PD) Disaggregation** in SGLang.
### Issues with Unified Scheduling
The conventional unified engine, which processes prefill and decode batches together, results in two significant problems:
1. **Prefill Interruption**: Incoming prefill batches frequently interrupt ongoing decode batches, causing substantial delays in token generation.
2. **DP Attention Imbalance**: In data-parallel (DP) attention, one DP worker may process a prefill batch while another handles a decode batch simultaneously, leading to increased decode latency.
PD Disaggregation resolves these issues by separating the two phases, enabling tailored optimizations for each.

For the design details, please refer to the [design document](https://docs.google.com/document/d/1rQXJwKd5b9b1aOzLh98mnyMhBMhlxXA5ATZTHoQrwvc/edit?tab=t.0).

Currently, we support Mooncake and NIXL as transfer engines.
## Mooncake
### Requirements
```bash
uv pip install mooncake-transfer-engine
```
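The `--disaggregation-ib-device` flag in the commands below expects the name of an RDMA NIC (for example `mlx5_roce0`). If you are unsure which devices your host exposes, one way to list them is the `ibv_devices` utility from rdma-core; this is only a suggestion for discovering the device name and is not part of SGLang itself.

```bash
# List the RDMA device names available on this host (e.g. mlx5_0, mlx5_roce0)
# and pick one to pass to --disaggregation-ib-device.
# Requires the libibverbs utilities (rdma-core) to be installed.
ibv_devices
```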
### Usage
#### Llama Single Node
```bash
# prefill server (listens on the default port 30000)
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0
# decode server on a second GPU
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0
# load balancer routing requests across the prefill and decode servers
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
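Once the three processes are running, all traffic goes through the load balancer on port 8000. As a quick sanity check, the sketch below sends one request to SGLang's native `/generate` endpoint through the load balancer; the prompt and sampling parameters are placeholders, and it assumes the load balancer proxies `/generate` as in a regular SGLang deployment.

```bash
# Hypothetical smoke test against the load balancer started above.
curl -s http://127.0.0.1:8000/generate \
  -H "Content-Type: application/json" \
  -d '{
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 16}
      }'
```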
#### DeepSeek Multi-Node
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
# prefill 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
```
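The `${device_name}`, `${local_ip}`, `${prefill_master_ip}`, and `${decode_master_ip}` placeholders above are cluster-specific. The sketch below uses illustrative values only, plus one plausible way to front the deployment by pointing the load balancer at the two master nodes, mirroring the single-node example; adapt it to your own topology.

```bash
# Illustrative values only; substitute your cluster's RDMA device and IPs.
export device_name=mlx5_roce0        # from `ibv_devices`
export local_ip=10.0.0.1             # IP of the node running the command
export prefill_master_ip=10.0.0.1    # node-rank 0 of the prefill group
export decode_master_ip=10.0.0.3     # node-rank 0 of the decode group

# Load balancer in front of the two groups (same pattern as the single-node example).
python -m sglang.srt.disaggregation.mini_lb \
  --prefill http://${prefill_master_ip}:30000 \
  --decode http://${decode_master_ip}:30001 \
  --host 0.0.0.0 --port 8000
```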
@@ -55,6 +55,7 @@ The core features include:
    backend/custom_chat_template.md
    backend/quantization.md
    backend/lora.ipynb
+   backend/pd_disaggregation.md
 
 .. toctree::
    :maxdepth: 1
@@ -33,28 +33,18 @@ class FakeKVSender(BaseKVSender):
         self,
         kv_indices: list[int],
         aux_index: Optional[int] = None,
-        dest_ranks: Optional[list[int]] = None,
     ):
         logger.info(
-            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
+            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
         )
         pass
 
     def send(
         self,
         kv_indices: npt.NDArray[np.int64],
-        index_slice: slice,
-        is_last: bool,
     ):
-        logger.info(
-            f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
-        )
-        if is_last:
-            self.has_sent = True
-            logger.info(f"FakeKVSender send success")
-        else:
-            self.has_sent = False
-            logger.info(f"FakeKVSender send fake transferring")
+        self.has_sent = True
+        logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -464,6 +464,8 @@ class MooncakeKVSender(BaseKVSender):
         self.aux_index = None
         self.bootstrap_server_url = bootstrap_addr
         self.session_id = self.kv_mgr.get_session_id()
+        # inner state
+        self.curr_idx = 0
 
     def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
         self.num_kv_indices = num_kv_indices
@@ -472,9 +474,11 @@ class MooncakeKVSender(BaseKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int64],
-        index_slice: slice,
-        is_last: bool,
     ):
+        index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
+        self.curr_idx += len(kv_indices)
+        is_last = self.curr_idx == self.num_kv_indices
+
         if not is_last:
             self.kv_mgr.add_transfer_request(
                 self.bootstrap_room, kv_indices, index_slice, False
@@ -384,11 +384,10 @@ class SchedulerDisaggregationPrefillMixin:
             if end_idx is not None
             else min(len(req.fill_ids), len(req.origin_input_ids))
         )
         last_chunk = token_id is not None
 
-        if (not last_chunk) and (
-            end_idx % page_size != 0
-        ):  # todo: remove the second condition
+        if not last_chunk:
             # if not the last chunk and the last page is partial, delay the last partial page to the next send
             end_idx = end_idx - end_idx % page_size
@@ -405,16 +404,10 @@ class SchedulerDisaggregationPrefillMixin:
             req.metadata_buffer_index, token_id
         )
         page_indices = kv_to_page_indices(kv_indices, page_size)
-        page_start_idx = start_idx // page_size
-        page_end_idx = page_start_idx + len(page_indices)
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return
-        req.disagg_kv_sender.send(
-            page_indices, slice(page_start_idx, page_end_idx), last_chunk
-        )
+        req.disagg_kv_sender.send(page_indices)
@@ -407,6 +407,7 @@ class GenerateReqInput:
                 else None
             ),
             return_hidden_states=self.return_hidden_states,
+            # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
             ),