Unverified Commit 88f9c347 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] use int32 for kv indices & get num_reserved_decode_tokens from server_args (#7214)

parent fff10809
......@@ -70,7 +70,7 @@ class BaseKVSender(ABC):
...
@abstractmethod
def send(self, kv_indices: npt.NDArray[np.int64]):
def send(self, kv_indices: npt.NDArray[np.int32]):
"""
Send the kv cache at the given kv indices to the decoder server
"""
......@@ -102,7 +102,7 @@ class BaseKVReceiver(ABC):
): ...
@abstractmethod
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
"""
Notify the prefill server about the kv indices and aux index
"""
......
......@@ -26,8 +26,8 @@ class FastQueue:
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]
) -> Tuple[List[npt.NDArray[np.int32]], List[npt.NDArray[np.int32]]]:
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
......
......@@ -158,6 +158,7 @@ class DecodePreallocQueue:
bootstrap_port: int,
max_total_num_tokens: int,
prefill_pp_size: int,
num_reserved_decode_tokens: int,
transfer_backend: TransferBackend,
):
self.req_to_token_pool = req_to_token_pool
......@@ -178,9 +179,7 @@ class DecodePreallocQueue:
self.bootstrap_port = bootstrap_port
self.max_total_num_tokens = max_total_num_tokens
self.prefill_pp_size = prefill_pp_size
self.num_reserved_decode_tokens = int(
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
)
self.num_reserved_decode_tokens = num_reserved_decode_tokens
self.transfer_backend = transfer_backend
# Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = []
......@@ -404,7 +403,6 @@ class DecodePreallocQueue:
]
.cpu()
.numpy()
.astype(np.int64)
)
decode_req.metadata_buffer_index = (
......
......@@ -48,7 +48,7 @@ class FakeKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int64],
kv_indices: npt.NDArray[np.int32],
):
self.has_sent = True
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
......
......@@ -59,7 +59,7 @@ class KVTransferError(Exception):
@dataclasses.dataclass
class TransferKVChunk:
room: int
prefill_kv_indices: npt.NDArray[np.int64]
prefill_kv_indices: npt.NDArray[np.int32]
index_slice: slice
is_last: bool
prefill_aux_index: Optional[int]
......@@ -72,7 +72,7 @@ class TransferInfo:
endpoint: str
dst_port: int
mooncake_session_id: str
dst_kv_indices: npt.NDArray[np.int64]
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int
required_dst_info_num: int
is_dummy: bool
......@@ -81,10 +81,10 @@ class TransferInfo:
def from_zmq(cls, msg: List[bytes]):
if msg[4] == b"" and msg[5] == b"":
is_dummy = True
dst_kv_indices = np.array([], dtype=np.int64)
dst_kv_indices = np.array([], dtype=np.int32)
dst_aux_index = None
else:
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
dst_aux_index = int(msg[5].decode("ascii"))
is_dummy = False
return cls(
......@@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager):
def send_kvcache(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int64],
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
dst_kv_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor,
):
# Group by indices
......@@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager):
def add_transfer_request(
self,
bootstrap_room: int,
kv_indices: npt.NDArray[np.int64],
kv_indices: npt.NDArray[np.int32],
index_slice: slice,
is_last: bool,
aux_index: Optional[int] = None,
......@@ -701,7 +701,7 @@ class MooncakeKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int64],
kv_indices: npt.NDArray[np.int32],
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
......@@ -971,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
......
......@@ -44,7 +44,7 @@ class TransferInfo:
agent_metadata: bytes
agent_name: str
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64]
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_ptrs: list[int]
dst_aux_index: int
dst_gpu_id: int
......@@ -62,7 +62,7 @@ class TransferInfo:
agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")),
......@@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager):
def send_kvcache(
self,
peer_name: str,
prefill_kv_indices: npt.NDArray[np.int64],
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
dst_kv_indices: npt.NDArray[np.int32],
dst_gpu_id: int,
notif: str,
):
......@@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager):
def add_transfer_request(
self,
bootstrap_room: int,
kv_indices: npt.NDArray[np.int64],
kv_indices: npt.NDArray[np.int32],
index_slice: slice,
is_last: bool,
chunk_id: int,
......@@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int64],
kv_indices: npt.NDArray[np.int32],
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
......@@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
self.started_transfer = False
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
......
......@@ -576,7 +576,6 @@ class SchedulerDisaggregationPrefillMixin:
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
.cpu()
.numpy()
.astype(np.int64)
)
req.start_send_idx = end_idx
if last_chunk:
......
......@@ -656,6 +656,7 @@ class Scheduler(
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
max_total_num_tokens=self.max_total_num_tokens,
prefill_pp_size=self.server_args.disaggregation_prefill_pp,
num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
transfer_backend=self.transfer_backend,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment