Unverified Commit 5c214257 authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

[PD] Raise error for incompatible mooncake version and some minor fixes (#7527)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent b8df43ab
......@@ -56,7 +56,7 @@ PD Disaggregation with Mooncake supports the following environment variables for
|:--------:|:-----------:|:--------:
| **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated by `int(0.75 * os.cpu_count()) // 8)`, which is limited to be larger than 4 and less than 12 to ensure efficiency and prevent thread race conditions |
| **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances will be sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, then we transfer requests one by one according to fcfs strategy | `4` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `30` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `120` |
#### Decode Server Configuration
| Variable | Description | Default |
......
......@@ -187,7 +187,7 @@ class MooncakeKVManager(BaseKVManager):
).start()
self.bootstrap_time_out = get_int_env_var(
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.heartbeat_failures = {}
......@@ -195,8 +195,8 @@ class MooncakeKVManager(BaseKVManager):
self.session_pool_lock = threading.Lock()
self.addr_to_rooms_tracker = defaultdict(set)
self.connection_lock = threading.Lock()
self.required_prefill_info_num_map: Dict[int, int] = {}
self.decode_kv_arrive_state: Dict[int, Set[int]] = defaultdict(set)
self.required_prefill_response_num_table: Dict[int, int] = {}
self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
# Heartbeat interval should be at least 2 seconds
self.heartbeat_interval = max(
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
......@@ -311,22 +311,23 @@ class MooncakeKVManager(BaseKVManager):
each page to ensure correctness for any page_size and head-slicing configuration.
This may introduce performance overhead (increased TTFT) for long sequences.
"""
# rank/kv_head config
# Extract configuration
local_tp_rank = self.kv_args.engine_rank
local_tp_size = self.tp_size // self.dp_size
num_kv_heads = self.kv_args.kv_head_num
num_layers = len(self.kv_args.kv_data_ptrs)
page_size = self.kv_args.page_size
# Calculate head distribution
heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
heads_per_prefill_rank = num_kv_heads
decode_global_head_start = dst_tp_rank * heads_per_decode_rank
prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
# decode config
decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
# Determine slicing parameters based on TP configuration
if local_tp_size > dst_tp_size:
src_head_offset = 0
num_heads_to_send = heads_per_prefill_rank
......@@ -340,7 +341,7 @@ class MooncakeKVManager(BaseKVManager):
for layer_id in range(num_layers):
item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
# Page stride on the target Decode rank for its slice pages
# Page stride on the target dst decode rank for its slice pages
item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
......@@ -349,12 +350,12 @@ class MooncakeKVManager(BaseKVManager):
)
return -1
# Calculate precise byte offset and length for the sub-slice within Prefill page data
# Calculate precise byte offset and length for the sub-slice within the prefill page data
src_slice_offset = src_head_offset * bytes_per_head
dst_slice_offset = dst_head_offset * bytes_per_head
slice_lens_per_page = num_heads_to_send * bytes_per_head
# Sanity check: The data sub-slice we intend to send should fit into D_n's page.
# Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
# This means slice_lens_per_page <= item_len_of_decode_rank_page
if slice_lens_per_page > item_len_of_decode_rank_page:
logger.error(
......@@ -365,15 +366,13 @@ class MooncakeKVManager(BaseKVManager):
return -1
layer_transfer_params.append(
(
self.kv_args.kv_data_ptrs[layer_id], # Prefill base ptr (all heads)
dst_kv_ptrs[
layer_id
], # Decode base ptr (for its slice for this layer)
item_len_of_prefill_rank_page, # Prefill page size (all heads)2048
item_len_of_decode_rank_page, # Decode page stride (for its slice page) 1024
src_slice_offset, # Offset to slice data in Prefill page
dst_slice_offset, # Offset to slice data in Decode page
slice_lens_per_page, # Length of slice data per page (actual data to send)
self.kv_args.kv_data_ptrs[layer_id],
dst_kv_ptrs[layer_id],
item_len_of_prefill_rank_page,
item_len_of_decode_rank_page,
src_slice_offset,
dst_slice_offset,
slice_lens_per_page,
)
)
......@@ -399,13 +398,13 @@ class MooncakeKVManager(BaseKVManager):
prefill_page_idx = int(prefill_kv_indices[i])
decode_page_idx = int(dst_kv_indices[i])
# Get the starting memory address for the current source and destination pages
# Get the starting addresses for the current src and dst pages
src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
# Iterate through each valid token slot within the current page
for token_slot_in_page in range(page_size):
# Calculate start address of the current token slot
# Calculate the start address of the current token slot
src_token_slot_start_addr = (
src_page_start_addr
+ token_slot_in_page * bytes_per_token_on_prefill
......@@ -415,7 +414,7 @@ class MooncakeKVManager(BaseKVManager):
+ token_slot_in_page * bytes_per_token_on_decode
)
# Calculate final source and destination addresses by applying head-slice offsets
# Calculate final src and dst addresses by applying head-slice offsets
src_slice_addr = src_token_slot_start_addr + src_offset
dst_slice_addr = dst_token_slot_start_addr + dst_offset
......@@ -585,9 +584,7 @@ class MooncakeKVManager(BaseKVManager):
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
target_rank_registration_info.dst_aux_ptrs,
req.dst_aux_index,
)
polls.append(True if ret == 0 else False)
......@@ -675,19 +672,19 @@ class MooncakeKVManager(BaseKVManager):
prefill_rank = int(prefill_rank.decode("ascii"))
if status == KVPoll.Success:
# record arrived prefill_rank
self.decode_kv_arrive_state[bootstrap_room].add(prefill_rank)
expected_prefill_num = self.required_prefill_info_num_map[
bootstrap_room
]
arrived_prefill_num = len(
self.decode_kv_arrive_state[bootstrap_room]
)
if (
self.is_mla_backend
or arrived_prefill_num == expected_prefill_num
):
self.update_status(bootstrap_room, KVPoll.Success)
if bootstrap_room in self.request_status:
self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
expected_response_num = (
self.required_prefill_response_num_table[bootstrap_room]
)
arrived_response_num = len(
self.prefill_response_tracker[bootstrap_room]
)
if (
self.is_mla_backend
or arrived_response_num == expected_response_num
):
self.update_status(bootstrap_room, KVPoll.Success)
elif status == KVPoll.Failed:
self.record_failure(
bootstrap_room,
......@@ -900,14 +897,13 @@ class MooncakeKVSender(BaseKVSender):
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.conclude_state = None
self.init_time = None
self.init_time = time.time()
# inner state
self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
self.aux_index = aux_index
self.init_time = time.time()
def send(
self,
......@@ -1031,7 +1027,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
)
self.required_dst_info_num = 1
self.required_prefill_info_num = 1
self.required_prefill_response_num = 1
self.target_tp_ranks = [self.target_tp_rank]
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
if not self.kv_mgr.is_mla_backend:
......@@ -1044,7 +1040,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.required_dst_info_num = (
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
)
self.required_prefill_info_num = 1
self.required_prefill_response_num = 1
self.target_tp_ranks = [self.target_tp_rank]
else:
if not self.kv_mgr.is_mla_backend:
......@@ -1067,7 +1063,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.required_prefill_info_num = (
self.required_prefill_response_num = (
prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
)
......@@ -1077,8 +1073,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
self.kv_mgr.required_prefill_info_num_map[self.bootstrap_room] = (
self.required_prefill_info_num
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
self.required_prefill_response_num
)
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
......@@ -1094,13 +1090,13 @@ class MooncakeKVReceiver(BaseKVReceiver):
)
if bootstrap_info is not None:
if self.kv_mgr.is_mla_backend:
# MLA :select one prefill rank as real rank
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
else:
# no-MLA:select all prefill ranks
# For non-MLA: all target_tp_ranks are selected real ranks
bootstrap_info["is_dummy"] = False
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
......@@ -1240,8 +1236,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
def clear(self) -> None:
if self.bootstrap_room in self.kv_mgr.request_status:
self.kv_mgr.request_status.pop(self.bootstrap_room)
self.kv_mgr.required_prefill_info_num_map.pop(self.bootstrap_room)
self.kv_mgr.decode_kv_arrive_state.pop(self.bootstrap_room)
if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
def failure_exception(self):
# Explicitly set the status to failure since this request has failed in another rank
......
......@@ -97,13 +97,19 @@ class MooncakeTransferEngine:
peer_buffer_addresses: List[int],
lengths: List[int],
) -> int:
"""Synchronously transfer data to the specified address."""
"""Synchronously transfer data to the specified addresses in batches."""
try:
ret = self.engine.batch_transfer_sync_write(
session_id, buffers, peer_buffer_addresses, lengths
)
except Exception:
ret = -1
# Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
if not hasattr(self.engine, "batch_transfer_sync_write"):
raise RuntimeError(
"Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
"Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
)
if ret < 0:
logger.debug(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment