Unverified commit 7d316991, authored by Byron Hsu, committed by GitHub

[PD] Update prefill.py (#7190)

parent ab1a4fa5
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import numpy.typing as npt
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.disaggregation.utils import DisaggregationMode
class KVArgs:
engine_rank: int
kv_data_ptrs: list[int]
kv_data_lens: list[int]
kv_item_lens: list[int]
aux_data_ptrs: list[int]
aux_data_lens: list[int]
aux_item_lens: list[int]
kv_data_ptrs: List[int]
kv_data_lens: List[int]
kv_item_lens: List[int]
aux_data_ptrs: List[int]
aux_data_lens: List[int]
aux_item_lens: List[int]
ib_device: str
ib_traffic_class: str
gpu_id: int
# for different tp sizes between prefill and decode
decode_tp_size: int
# for prefill pp size
prefill_pp_size: int
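# Hedged note (not part of this diff): on the prefill side these two fields are
# populated from server args in PrefillBootstrapQueue._init_kv_manager below,
# roughly decode_tp_size = disaggregation_decode_tp // disaggregation_decode_dp
# and prefill_pp_size = pp_size, so the transfer backend can map KV slices
# across heterogeneous tp/pp layouts.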
class KVPoll:
......@@ -45,7 +54,12 @@ class BaseKVSender(ABC):
@abstractmethod
def __init__(
self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
self,
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: int,
dest_tp_ranks: List[int],
pp_rank: int,
): ...
@abstractmethod
......
import threading
from collections import deque
from typing import List, Tuple
import numpy as np
import numpy.typing as npt
class FastQueue:
def __init__(self):
self._buf = deque()
self._cond = threading.Condition()
def put(self, item):
with self._cond:
self._buf.append(item)
# wake up one thread waiting in get()
self._cond.notify()
def get(self):
with self._cond:
# if the queue is empty, block until notified
while not self._buf:
self._cond.wait()
return self._buf.popleft()
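# Hedged usage sketch (not part of this diff): FastQueue is a minimal blocking
# queue; a worker thread blocks in get() until another thread put()s a job.
_q = FastQueue()
_worker = threading.Thread(target=lambda: print(_q.get()), daemon=True)
_worker.start()
_q.put("kv-transfer-job")  # wakes the worker, which prints the item and exits
_worker.join()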
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
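# Hedged example (not part of this diff; values are illustrative): contiguous
# runs in both index arrays are grouped together so each run can be moved as
# one chunk.
_src = np.array([0, 1, 2, 5, 6], dtype=np.int64)
_dst = np.array([10, 11, 12, 20, 21], dtype=np.int64)
assert group_concurrent_contiguous(_src, _dst) == (
[[0, 1, 2], [5, 6]],
[[10, 11, 12], [20, 21]],
)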
......@@ -33,8 +33,8 @@ from torch.distributed import ProcessGroup
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
FakeBootstrapHost,
KVClassType,
MetadataBuffers,
ReqToMetadataIdxAllocator,
......@@ -207,7 +207,7 @@ class DecodePreallocQueue:
def add(self, req: Req) -> None:
"""Add a request to the pending queue."""
if req.bootstrap_host == FakeBootstrapHost:
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
# Fake transfer for warmup reqs
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
else:
......
......@@ -17,7 +17,14 @@ logger = logging.getLogger(__name__)
# For warmup reqs, we skip the kv transfer and use the fake sender and receiver
class FakeKVSender(BaseKVSender):
def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
def __init__(
self,
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: int,
dest_tp_ranks: List[int],
pp_rank: int,
):
self.has_sent = False
def poll(self) -> KVPoll:
......
......@@ -28,12 +28,12 @@ from sglang.srt.disaggregation.base.conn import (
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
from sglang.srt.disaggregation.common.utils import (
FastQueue,
group_concurrent_contiguous,
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_free_port,
......@@ -677,7 +677,12 @@ class MooncakeKVManager(BaseKVManager):
class MooncakeKVSender(BaseKVSender):
def __init__(
self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
self,
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: int,
dest_tp_ranks: List[int],
pp_rank: int,
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
......
......@@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import (
CommonKVManager,
CommonKVReceiver,
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
group_concurrent_contiguous,
)
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_by_remote
......@@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager):
class NixlKVSender(BaseKVSender):
def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
def __init__(
self,
mgr: NixlKVManager,
bootstrap_addr: str,
bootstrap_room: int,
dest_tp_ranks: List[int],
pp_rank: int,
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.aux_index = None
......
......@@ -27,10 +27,10 @@ from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
FakeBootstrapHost,
KVClassType,
MetadataBuffers,
ReqToMetadataIdxAllocator,
......@@ -51,7 +51,6 @@ if TYPE_CHECKING:
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
from sglang.srt.mem_cache.memory_pool import KVCache
logger = logging.getLogger(__name__)
......@@ -68,35 +67,45 @@ class PrefillBootstrapQueue:
metadata_buffers: MetadataBuffers,
tp_rank: int,
tp_size: int,
gpu_id: int,
bootstrap_port: int,
gloo_group: ProcessGroup,
transfer_backend: TransferBackend,
max_total_num_tokens: int,
decode_tp_size: int,
decode_dp_size: int,
scheduler: Scheduler,
pp_rank: int,
pp_size: int,
transfer_backend: TransferBackend,
):
self.token_to_kv_pool = token_to_kv_pool
self.draft_token_to_kv_pool = draft_token_to_kv_pool
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
self.metadata_buffers = metadata_buffers
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.tp_rank = tp_rank
self.tp_size = tp_size
self.transfer_backend = transfer_backend
self.scheduler = scheduler
self.kv_manager = self._init_kv_manager()
self.decode_tp_size = decode_tp_size
self.decode_dp_size = decode_dp_size
self.pp_rank = pp_rank
self.pp_size = pp_size
self.gpu_id = gpu_id
self.bootstrap_port = bootstrap_port
self.queue: List[Req] = []
self.pp_rank = pp_rank
self.pp_size = pp_size
self.gloo_group = gloo_group
self.bootstrap_port = bootstrap_port
def store_prefill_results(self, idx: int, token_id: int):
assert token_id >= 0, f"token_id: {token_id} is negative"
output_id_buffer = self.metadata_buffers[0]
output_id_buffer[idx] = token_id
self.max_total_num_tokens = max_total_num_tokens
self.scheduler = scheduler
self.transfer_backend = transfer_backend
self.kv_manager = self._init_kv_manager()
def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs()
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
kv_args = kv_args_class()
kv_args.engine_rank = self.tp_rank
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
kv_args.prefill_pp_size = self.pp_size
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
......@@ -115,12 +124,12 @@ class PrefillBootstrapQueue:
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens
# Define req -> input ids buffer
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
)
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(
kv_args,
......@@ -130,23 +139,39 @@ class PrefillBootstrapQueue:
)
return kv_manager
def add(self, req: Req) -> None:
if req.bootstrap_host == FakeBootstrapHost:
# Fake transfer for warmup reqs
def add(self, req: Req, num_kv_heads: int) -> None:
if self._check_if_req_exceed_kv_capacity(req):
return
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
else:
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
dest_tp_ranks = [self.tp_rank]
req.disagg_kv_sender = kv_sender_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
dest_tp_ranks=dest_tp_ranks,
pp_rank=self.pp_rank,
)
self._process_req(req)
self.queue.append(req)
def extend(self, reqs: List[Req]) -> None:
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
for req in reqs:
self.add(req)
self.add(req, num_kv_heads)
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
if len(req.origin_input_ids) > self.max_total_num_tokens:
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
logger.error(message)
prepare_abort(req, message)
self.scheduler.stream_output([req], req.return_logprob)
return True
return False
def _process_req(self, req: Req) -> None:
"""
......@@ -154,19 +179,40 @@ class PrefillBootstrapQueue:
"""
req.sampling_params.max_new_tokens = 1
def pop_bootstrapped(self) -> List[Req]:
"""pop the reqs which has finished bootstrapping"""
def pop_bootstrapped(
self,
return_failed_reqs: bool = False,
rids_to_check: Optional[List[str]] = None,
) -> List[Req]:
"""
Pop the requests that have finished bootstrapping.
return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank.
rids_to_check: For PP, on rank > 0, check that the rids from the previous rank are consistent with the current rank.
"""
bootstrapped_reqs = []
failed_reqs = []
indices_to_remove = set()
if len(self.queue) == 0:
return []
if return_failed_reqs is False:
return []
else:
return [], []
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.queue], self.gloo_group
)
for i, (req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None:
# skip requests that are not in rids_to_check
if req.rid not in rids_to_check:
continue
# Either waiting for input or failed
assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
if poll == KVPoll.Bootstrapping:
continue
elif poll == KVPoll.Failed:
......@@ -181,9 +227,10 @@ class PrefillBootstrapQueue:
)
self.scheduler.stream_output([req], req.return_logprob)
indices_to_remove.add(i)
failed_reqs.append(req)
continue
# KV.WaitingForInput
# KVPoll.WaitingForInput - init the sender here
num_kv_indices = len(req.origin_input_ids)
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
break
......@@ -192,9 +239,9 @@ class PrefillBootstrapQueue:
self.req_to_metadata_buffer_idx_allocator.alloc()
)
assert req.metadata_buffer_index is not None
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
bootstrapped_reqs.append(req)
indices_to_remove.add(i)
......@@ -202,7 +249,10 @@ class PrefillBootstrapQueue:
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
]
return bootstrapped_reqs
if return_failed_reqs is False:
return bootstrapped_reqs
else:
return bootstrapped_reqs, failed_reqs
class SchedulerDisaggregationPrefillMixin:
......@@ -211,7 +261,7 @@ class SchedulerDisaggregationPrefillMixin:
"""
@torch.no_grad()
def event_loop_normal_disagg_prefill(self: Scheduler):
def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
"""A normal scheduler loop for prefill worker in disaggregation mode."""
while True:
......@@ -229,7 +279,6 @@ class SchedulerDisaggregationPrefillMixin:
or self.server_args.enable_sp_layernorm
):
batch, _ = self.prepare_dp_attn_batch(batch)
self.cur_batch = batch
if batch:
......@@ -250,7 +299,7 @@ class SchedulerDisaggregationPrefillMixin:
self.running_batch.batch_is_full = False
@torch.no_grad()
def event_loop_overlap_disagg_prefill(self: Scheduler):
def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
self.result_queue = deque()
while True:
......@@ -268,9 +317,7 @@ class SchedulerDisaggregationPrefillMixin:
or self.server_args.enable_sp_layernorm
):
batch, _ = self.prepare_dp_attn_batch(batch)
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
......@@ -287,6 +334,9 @@ class SchedulerDisaggregationPrefillMixin:
if self.last_batch:
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
if len(self.disagg_prefill_inflight_queue) > 0:
......@@ -309,7 +359,7 @@ class SchedulerDisaggregationPrefillMixin:
launch_done: Optional[threading.Event] = None,
) -> None:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
Adapted from process_batch_result_prefill
"""
(
......@@ -325,7 +375,7 @@ class SchedulerDisaggregationPrefillMixin:
)
logprob_pt = 0
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
......@@ -397,11 +447,15 @@ class SchedulerDisaggregationPrefillMixin:
# We need to remove the sync in the following function for overlap schedule.
self.set_next_batch_sampling_info_done(batch)
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
def process_disagg_prefill_inflight_queue(
self: Scheduler, rids_to_check: Optional[List[str]] = None
) -> List[Req]:
"""
Poll the requests in the middle of transfer. If done, return the request.
rids_to_check: For PP, on rank > 0, check that the rids from the previous rank are consistent with the current rank.
"""
assert len(self.disagg_prefill_inflight_queue) > 0
if len(self.disagg_prefill_inflight_queue) == 0:
return []
done_reqs = []
......@@ -413,6 +467,14 @@ class SchedulerDisaggregationPrefillMixin:
undone_reqs: List[Req] = []
# Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove the request from the queue
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
if rids_to_check is not None:
if req.rid not in rids_to_check:
undone_reqs.append(req)
continue
assert poll == KVPoll.Success or poll == KVPoll.Failed
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
undone_reqs.append(req)
elif poll == KVPoll.Success: # transfer done
......@@ -434,11 +496,8 @@ class SchedulerDisaggregationPrefillMixin:
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
)
done_reqs.append(req)
for req in done_reqs:
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
req.metadata_buffer_index
)
else:
assert False, f"Unexpected polling state {poll=}"
# Stream requests which have finished transfer
self.stream_output(
......@@ -446,9 +505,32 @@ class SchedulerDisaggregationPrefillMixin:
any(req.return_logprob for req in done_reqs),
None,
)
for req in done_reqs:
req: Req
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
req.metadata_buffer_index = -1
self.disagg_prefill_inflight_queue = undone_reqs
return done_reqs
def get_transferred_rids(self: Scheduler) -> List[str]:
"""
Used by PP, get the transferred rids but **do not pop**
"""
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
self.tp_worker.get_tp_group().cpu_group,
)
transferred_rids: List[str] = []
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
if poll == KVPoll.Success or poll == KVPoll.Failed:
transferred_rids.append(req.rid)
return transferred_rids
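# Hedged PP flow sketch (call pattern assumed from the docstrings above, not
# spelled out in this diff): the first pp rank collects finished rids without
# popping, and later ranks pass that list as rids_to_check so all ranks pop
# the same set of requests from disagg_prefill_inflight_queue.
#
#   transferred = scheduler.get_transferred_rids()            # on pp rank 0
#   done = scheduler.process_disagg_prefill_inflight_queue(
#       rids_to_check=transferred,                            # on pp rank > 0
#   )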
def process_prefill_chunk(self: Scheduler) -> None:
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.chunked_req:
......
......@@ -14,15 +14,15 @@ import requests
import torch
import torch.distributed as dist
from sglang.srt.utils import get_ip, get_local_ip_by_remote
from sglang.srt.utils import get_ip
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
FakeBootstrapHost = "2.2.2.2"
# env var for testing failure, convert to float explicitly
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
#########################
# Constants & Enums
#########################
FAKE_BOOTSTRAP_HOST = "2.2.2.2"
class DisaggregationMode(Enum):
......@@ -31,6 +31,14 @@ class DisaggregationMode(Enum):
DECODE = "decode"
#########################
# Synchronization
#########################
# env var for testing failure, convert to float explicitly
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
def poll_and_all_reduce(pollers, gloo_group):
# with probability FAILURE_PROB, fail the poll to simulate a failure
if FAILURE_PROB > 0:
......@@ -47,6 +55,11 @@ def poll_and_all_reduce(pollers, gloo_group):
return tensor_to_reduce.tolist()
#########################
# Metadata Buffers
#########################
class ReqToMetadataIdxAllocator:
"""A memory pool that maps a request to its first output token location."""
......@@ -70,6 +83,91 @@ class ReqToMetadataIdxAllocator:
self.free_slots.append(free_index)
class MetadataBuffers:
def __init__(self, size: int, max_top_logprobs_num: int = 128):
# TODO: abort top_logprobs_num > 128 in PD
# We transfer the metadata of the first output token to the decode side
# The minimal size for RDMA is 64 bytes, so we pad it to > 64 bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device="cpu"
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device="cpu"
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
)
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(),
]
data_lens = [
self.output_ids.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes,
]
item_lens = [
self.output_ids[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes,
]
return ptrs, data_lens, item_lens
def get_buf(self, idx: int):
return (
self.output_ids[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx],
)
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
req.output_token_logprobs_val[0]
)
if req.output_token_logprobs_idx: # not none or empty list
self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
req.output_token_logprobs_idx[0]
)
if req.output_top_logprobs_val: # not none or empty list
self.output_top_logprobs_val[req.metadata_buffer_index][
: len(req.output_top_logprobs_val[0])
] = torch.tensor(
req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
)
if req.output_top_logprobs_idx: # not none or empty list
self.output_top_logprobs_idx[req.metadata_buffer_index][
: len(req.output_top_logprobs_idx[0])
] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
)
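# Hedged usage sketch (the local import and buffer size are illustrative): the
# prefill side hands these buffers to the transfer backend via KVArgs,
# mirroring the aux_* assignments in PrefillBootstrapQueue._init_kv_manager.
from sglang.srt.disaggregation.base import KVArgs as _KVArgs

_buffers = MetadataBuffers(size=256)
_kv_args = _KVArgs()
_kv_args.aux_data_ptrs, _kv_args.aux_data_lens, _kv_args.aux_item_lens = (
_buffers.get_buf_infos()
)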
#########################
# Transfer Backend
#########################
class TransferBackend(Enum):
MOONCAKE = "mooncake"
NIXL = "nixl"
......@@ -77,6 +175,7 @@ class TransferBackend(Enum):
class KVClassType(Enum):
KVARGS = "kvargs"
MANAGER = "manager"
SENDER = "sender"
RECEIVER = "receiver"
......@@ -87,6 +186,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
if transfer_backend == TransferBackend.MOONCAKE:
from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.mooncake import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
......@@ -95,6 +195,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
)
class_mapping = {
KVClassType.KVARGS: KVArgs,
KVClassType.MANAGER: MooncakeKVManager,
KVClassType.SENDER: MooncakeKVSender,
KVClassType.RECEIVER: (MooncakeKVReceiver),
......@@ -102,6 +203,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
}
return class_mapping.get(class_type)
if transfer_backend == TransferBackend.NIXL:
from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.nixl import (
NixlKVBootstrapServer,
NixlKVManager,
......@@ -110,6 +212,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
)
class_mapping = {
KVClassType.KVARGS: KVArgs,
KVClassType.MANAGER: NixlKVManager,
KVClassType.SENDER: NixlKVSender,
KVClassType.RECEIVER: (NixlKVReceiver),
......@@ -117,9 +220,11 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
}
return class_mapping.get(class_type)
if transfer_backend == TransferBackend.FAKE:
from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
class_mapping = {
KVClassType.KVARGS: KVArgs,
KVClassType.SENDER: FakeKVSender,
KVClassType.RECEIVER: (FakeKVReceiver),
}
......@@ -128,6 +233,11 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
#########################
# KV Pages
#########################
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
# 1. The page is guaranteed to be full except the last page.
# 2. page index = kv_index // page_size
......@@ -143,6 +253,11 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
return (num_kv_indices + page_size - 1) // page_size
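# Hedged example (values are illustrative): 130 KV indices with page_size 64
# occupy ceil(130 / 64) = 3 pages; only the last page is partially filled.
assert kv_to_page_num(130, 64) == 3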
#########################
# PDLB Registry
#########################
@dataclasses.dataclass
class PDRegistryRequest:
"""A request to register a machine itself to the LB."""
......@@ -181,6 +296,11 @@ def register_disaggregation_server(
)
#########################
# Misc
#########################
def is_mla_backend(target_kv_pool) -> bool:
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
......@@ -200,119 +320,3 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
req.input_top_logprobs_idx = []
req.input_token_ids_logprobs_val = []
req.input_token_ids_logprobs_idx = []
class MetadataBuffers:
def __init__(self, size: int, max_top_logprobs_num: int = 128):
# TODO: abort top_logprobs_num > 128 in PD
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device="cpu"
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device="cpu"
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
)
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(),
]
data_lens = [
self.output_ids.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes,
]
item_lens = [
self.output_ids[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes,
]
return ptrs, data_lens, item_lens
def get_buf(self, idx: int):
return (
self.output_ids[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx],
)
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
req.output_token_logprobs_val[0]
)
if req.output_token_logprobs_idx: # not none or empty list
self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
req.output_token_logprobs_idx[0]
)
if req.output_top_logprobs_val: # not none or empty list
self.output_top_logprobs_val[req.metadata_buffer_index][
: len(req.output_top_logprobs_val[0])
] = torch.tensor(
req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
)
if req.output_top_logprobs_idx: # not none or empty list
self.output_top_logprobs_idx[req.metadata_buffer_index][
: len(req.output_top_logprobs_idx[0])
] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
)
class FastQueue:
def __init__(self):
self._buf = deque()
self._cond = threading.Condition()
def put(self, item):
with self._cond:
self._buf.append(item)
# wake up a thread of wait()
self._cond.notify()
def get(self):
with self._cond:
# if queue is empty ,block until is notified()
while not self._buf:
self._cond.wait()
return self._buf.popleft()
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
......@@ -43,7 +43,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import (
FakeBootstrapHost,
FAKE_BOOTSTRAP_HOST,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
......@@ -878,7 +878,7 @@ def _wait_and_warmup(
"max_new_tokens": 8,
"ignore_eos": True,
},
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
# This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room": [
......
......@@ -619,7 +619,7 @@ class Scheduler(
self.disaggregation_mode == DisaggregationMode.DECODE
): # *2 for the headroom.
buffer_size = (self.req_to_token_pool.size) * 2
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
......@@ -627,7 +627,7 @@ class Scheduler(
# The decode requests polling kv cache
self.disagg_decode_transfer_queue = DecodeTransferQueue(
gloo_group=self.attn_tp_cpu_group,
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
metadata_buffers=self.disagg_metadata_buffers,
scheduler=self,
tree_cache=self.tree_cache,
......@@ -642,7 +642,7 @@ class Scheduler(
if self.draft_worker is None
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
metadata_buffers=self.disagg_metadata_buffers,
scheduler=self,
transfer_queue=self.disagg_decode_transfer_queue,
......@@ -660,7 +660,7 @@ class Scheduler(
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
# *2 for the headroom.
buffer_size = self.max_running_requests * 2
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
......@@ -672,14 +672,20 @@ class Scheduler(
if self.draft_worker is None
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
metadata_buffers=self.disagg_metadata_buffers,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
gpu_id=self.gpu_id,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.attn_tp_cpu_group,
transfer_backend=self.transfer_backend,
max_total_num_tokens=self.max_total_num_tokens,
decode_tp_size=self.server_args.disaggregation_decode_tp,
decode_dp_size=self.server_args.disaggregation_decode_dp,
scheduler=self,
pp_rank=self.pp_rank,
pp_size=self.pp_size,
transfer_backend=self.transfer_backend,
)
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []
......@@ -1110,7 +1116,9 @@ class Scheduler(
def _add_request_to_queue(self, req: Req):
req.queue_time_start = time.perf_counter()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.add(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
......@@ -1118,7 +1126,9 @@ class Scheduler(
def _extend_requests_to_queue(self, reqs: List[Req]):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.extend(reqs)
self.disagg_prefill_bootstrap_queue.extend(
reqs, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs)
......
......@@ -227,6 +227,9 @@ class ServerArgs:
disaggregation_mode: str = "null"
disaggregation_transfer_backend: str = "mooncake"
disaggregation_bootstrap_port: int = 8998
disaggregation_decode_tp: Optional[int] = None
disaggregation_decode_dp: Optional[int] = None
disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
pdlb_url: Optional[str] = None
......@@ -505,12 +508,27 @@ class ServerArgs:
self.triton_attention_num_kv_splits = 16
# PD disaggregation
if self.disaggregation_mode == "prefill":
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")
elif self.disaggregation_mode == "decode":
if self.disaggregation_mode == "decode":
assert (
self.disaggregation_decode_tp is None
), "Cannot set --disaggregation-decode-tp for the decode engine."
assert (
self.disaggregation_decode_dp is None
), "Cannot set --disaggregation-decode-dp for the decode engine."
self.disable_radix_cache = True
logger.warning("KV cache is forced as chunk cache for decode server")
elif self.disaggregation_mode == "prefill":
if self.disaggregation_decode_tp is None:
self.disaggregation_decode_tp = self.tp_size
if self.disaggregation_decode_dp is None:
self.disaggregation_decode_dp = self.dp_size
self.disaggregation_prefill_pp = self.pp_size
self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
"1" if self.enable_torch_compile else "0"
......@@ -520,6 +538,14 @@ class ServerArgs:
"1" if self.disable_outlines_disk_cache else "0"
)
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
larger_tp = max(decode_tp, prefill_tp)
smaller_tp = min(decode_tp, prefill_tp)
assert larger_tp % smaller_tp == 0, (
"Different tp size is supported only when one tp is multiple of the other. "
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
)
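# Hedged example (values are illustrative): prefill tp=4 with decode tp=8 (or
# the reverse) passes, while sizes such as 4 and 6 trip the assertion.
#   self.validate_disagg_tp_size(4, 8)   # ok: 8 % 4 == 0
#   self.validate_disagg_tp_size(4, 6)   # AssertionError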
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
......@@ -1512,6 +1538,24 @@ class ServerArgs:
default=ServerArgs.disaggregation_bootstrap_port,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
parser.add_argument(
"--disaggregation-decode-tp",
type=int,
default=ServerArgs.disaggregation_decode_tp,
help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
)
parser.add_argument(
"--disaggregation-decode-dp",
type=int,
default=ServerArgs.disaggregation_decode_dp,
help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
)
parser.add_argument(
"--disaggregation-prefill-pp",
type=int,
default=ServerArgs.disaggregation_prefill_pp,
help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
)
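# Hedged launch sketch (model path and sizes are placeholders; only the flags
# defined above plus --disaggregation-mode and --model-path are assumed): a
# prefill server declaring the decode-side topology so the KV sender can plan
# transfers.
#   python -m sglang.launch_server --model-path <model> \
#       --disaggregation-mode prefill \
#       --disaggregation-decode-tp 8 --disaggregation-decode-dp 2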
parser.add_argument(
"--disaggregation-ib-device",
type=str,
......