Unverified Commit 88f9c347 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] use int32 for kv indices & get num_reserved_decode_tokens from server_args (#7214)

parent fff10809
...@@ -70,7 +70,7 @@ class BaseKVSender(ABC): ...@@ -70,7 +70,7 @@ class BaseKVSender(ABC):
... ...
@abstractmethod @abstractmethod
def send(self, kv_indices: npt.NDArray[np.int64]): def send(self, kv_indices: npt.NDArray[np.int32]):
""" """
Send the kv cache at the given kv indices to the decoder server Send the kv cache at the given kv indices to the decoder server
""" """
...@@ -102,7 +102,7 @@ class BaseKVReceiver(ABC): ...@@ -102,7 +102,7 @@ class BaseKVReceiver(ABC):
): ... ): ...
@abstractmethod @abstractmethod
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
""" """
Notify the prefill server about the kv indices and aux index Notify the prefill server about the kv indices and aux index
""" """
......
...@@ -26,8 +26,8 @@ class FastQueue: ...@@ -26,8 +26,8 @@ class FastQueue:
def group_concurrent_contiguous( def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: ) -> Tuple[List[npt.NDArray[np.int32]], List[npt.NDArray[np.int32]]]:
"""Vectorised NumPy implementation.""" """Vectorised NumPy implementation."""
if src_indices.size == 0: if src_indices.size == 0:
return [], [] return [], []
......
...@@ -158,6 +158,7 @@ class DecodePreallocQueue: ...@@ -158,6 +158,7 @@ class DecodePreallocQueue:
bootstrap_port: int, bootstrap_port: int,
max_total_num_tokens: int, max_total_num_tokens: int,
prefill_pp_size: int, prefill_pp_size: int,
num_reserved_decode_tokens: int,
transfer_backend: TransferBackend, transfer_backend: TransferBackend,
): ):
self.req_to_token_pool = req_to_token_pool self.req_to_token_pool = req_to_token_pool
...@@ -178,9 +179,7 @@ class DecodePreallocQueue: ...@@ -178,9 +179,7 @@ class DecodePreallocQueue:
self.bootstrap_port = bootstrap_port self.bootstrap_port = bootstrap_port
self.max_total_num_tokens = max_total_num_tokens self.max_total_num_tokens = max_total_num_tokens
self.prefill_pp_size = prefill_pp_size self.prefill_pp_size = prefill_pp_size
self.num_reserved_decode_tokens = int( self.num_reserved_decode_tokens = num_reserved_decode_tokens
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
)
self.transfer_backend = transfer_backend self.transfer_backend = transfer_backend
# Queue for requests pending pre-allocation # Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = [] self.queue: List[DecodeRequest] = []
...@@ -404,7 +403,6 @@ class DecodePreallocQueue: ...@@ -404,7 +403,6 @@ class DecodePreallocQueue:
] ]
.cpu() .cpu()
.numpy() .numpy()
.astype(np.int64)
) )
decode_req.metadata_buffer_index = ( decode_req.metadata_buffer_index = (
......
...@@ -48,7 +48,7 @@ class FakeKVSender(BaseKVSender): ...@@ -48,7 +48,7 @@ class FakeKVSender(BaseKVSender):
def send( def send(
self, self,
kv_indices: npt.NDArray[np.int64], kv_indices: npt.NDArray[np.int32],
): ):
self.has_sent = True self.has_sent = True
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}") logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
......
...@@ -59,7 +59,7 @@ class KVTransferError(Exception): ...@@ -59,7 +59,7 @@ class KVTransferError(Exception):
@dataclasses.dataclass @dataclasses.dataclass
class TransferKVChunk: class TransferKVChunk:
room: int room: int
prefill_kv_indices: npt.NDArray[np.int64] prefill_kv_indices: npt.NDArray[np.int32]
index_slice: slice index_slice: slice
is_last: bool is_last: bool
prefill_aux_index: Optional[int] prefill_aux_index: Optional[int]
...@@ -72,7 +72,7 @@ class TransferInfo: ...@@ -72,7 +72,7 @@ class TransferInfo:
endpoint: str endpoint: str
dst_port: int dst_port: int
mooncake_session_id: str mooncake_session_id: str
dst_kv_indices: npt.NDArray[np.int64] dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int dst_aux_index: int
required_dst_info_num: int required_dst_info_num: int
is_dummy: bool is_dummy: bool
...@@ -81,10 +81,10 @@ class TransferInfo: ...@@ -81,10 +81,10 @@ class TransferInfo:
def from_zmq(cls, msg: List[bytes]): def from_zmq(cls, msg: List[bytes]):
if msg[4] == b"" and msg[5] == b"": if msg[4] == b"" and msg[5] == b"":
is_dummy = True is_dummy = True
dst_kv_indices = np.array([], dtype=np.int64) dst_kv_indices = np.array([], dtype=np.int32)
dst_aux_index = None dst_aux_index = None
else: else:
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64) dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
dst_aux_index = int(msg[5].decode("ascii")) dst_aux_index = int(msg[5].decode("ascii"))
is_dummy = False is_dummy = False
return cls( return cls(
...@@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager): ...@@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager):
def send_kvcache( def send_kvcache(
self, self,
mooncake_session_id: str, mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int64], prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int], dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64], dst_kv_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor, executor: concurrent.futures.ThreadPoolExecutor,
): ):
# Group by indices # Group by indices
...@@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager):
def add_transfer_request( def add_transfer_request(
self, self,
bootstrap_room: int, bootstrap_room: int,
kv_indices: npt.NDArray[np.int64], kv_indices: npt.NDArray[np.int32],
index_slice: slice, index_slice: slice,
is_last: bool, is_last: bool,
aux_index: Optional[int] = None, aux_index: Optional[int] = None,
...@@ -701,7 +701,7 @@ class MooncakeKVSender(BaseKVSender): ...@@ -701,7 +701,7 @@ class MooncakeKVSender(BaseKVSender):
def send( def send(
self, self,
kv_indices: npt.NDArray[np.int64], kv_indices: npt.NDArray[np.int32],
): ):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices) self.curr_idx += len(kv_indices)
...@@ -971,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -971,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
cls._socket_locks[endpoint] = threading.Lock() cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint] return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = ( self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
......
...@@ -44,7 +44,7 @@ class TransferInfo: ...@@ -44,7 +44,7 @@ class TransferInfo:
agent_metadata: bytes agent_metadata: bytes
agent_name: str agent_name: str
dst_kv_ptrs: list[int] dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64] dst_kv_indices: npt.NDArray[np.int32]
dst_aux_ptrs: list[int] dst_aux_ptrs: list[int]
dst_aux_index: int dst_aux_index: int
dst_gpu_id: int dst_gpu_id: int
...@@ -62,7 +62,7 @@ class TransferInfo: ...@@ -62,7 +62,7 @@ class TransferInfo:
agent_metadata=msg[3], agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"), agent_name=msg[4].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64), dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])), dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_aux_index=int(msg[8].decode("ascii")), dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")), dst_gpu_id=int(msg[9].decode("ascii")),
...@@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager): ...@@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager):
def send_kvcache( def send_kvcache(
self, self,
peer_name: str, peer_name: str,
prefill_kv_indices: npt.NDArray[np.int64], prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int], dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64], dst_kv_indices: npt.NDArray[np.int32],
dst_gpu_id: int, dst_gpu_id: int,
notif: str, notif: str,
): ):
...@@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager): ...@@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager):
def add_transfer_request( def add_transfer_request(
self, self,
bootstrap_room: int, bootstrap_room: int,
kv_indices: npt.NDArray[np.int64], kv_indices: npt.NDArray[np.int32],
index_slice: slice, index_slice: slice,
is_last: bool, is_last: bool,
chunk_id: int, chunk_id: int,
...@@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender): ...@@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender):
def send( def send(
self, self,
kv_indices: npt.NDArray[np.int64], kv_indices: npt.NDArray[np.int32],
): ):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices) self.curr_idx += len(kv_indices)
...@@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
self.started_transfer = False self.started_transfer = False
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank) super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = ( self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
......
...@@ -576,7 +576,6 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -576,7 +576,6 @@ class SchedulerDisaggregationPrefillMixin:
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx] self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
.cpu() .cpu()
.numpy() .numpy()
.astype(np.int64)
) )
req.start_send_idx = end_idx req.start_send_idx = end_idx
if last_chunk: if last_chunk:
......
...@@ -656,6 +656,7 @@ class Scheduler( ...@@ -656,6 +656,7 @@ class Scheduler(
bootstrap_port=self.server_args.disaggregation_bootstrap_port, bootstrap_port=self.server_args.disaggregation_bootstrap_port,
max_total_num_tokens=self.max_total_num_tokens, max_total_num_tokens=self.max_total_num_tokens,
prefill_pp_size=self.server_args.disaggregation_prefill_pp, prefill_pp_size=self.server_args.disaggregation_prefill_pp,
num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
transfer_backend=self.transfer_backend, transfer_backend=self.transfer_backend,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment