Unverified Commit 5c214257 authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

[PD] Raise error for incompatible mooncake version and some minor fixes (#7527)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent b8df43ab
...@@ -56,7 +56,7 @@ PD Disaggregation with Mooncake supports the following environment variables for ...@@ -56,7 +56,7 @@ PD Disaggregation with Mooncake supports the following environment variables for
|:--------:|:-----------:|:--------: |:--------:|:-----------:|:--------:
| **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated by `int(0.75 * os.cpu_count()) // 8)`, which is limited to be larger than 4 and less than 12 to ensure efficiency and prevent thread race conditions | | **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated by `int(0.75 * os.cpu_count()) // 8)`, which is limited to be larger than 4 and less than 12 to ensure efficiency and prevent thread race conditions |
| **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances will be sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, then we transfer requests one by one according to fcfs strategy | `4` | | **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances will be sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, then we transfer requests one by one according to fcfs strategy | `4` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `30` | | **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `120` |
#### Decode Server Configuration #### Decode Server Configuration
| Variable | Description | Default | | Variable | Description | Default |
......
...@@ -187,7 +187,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -187,7 +187,7 @@ class MooncakeKVManager(BaseKVManager):
).start() ).start()
self.bootstrap_time_out = get_int_env_var( self.bootstrap_time_out = get_int_env_var(
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
) )
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.heartbeat_failures = {} self.heartbeat_failures = {}
...@@ -195,8 +195,8 @@ class MooncakeKVManager(BaseKVManager): ...@@ -195,8 +195,8 @@ class MooncakeKVManager(BaseKVManager):
self.session_pool_lock = threading.Lock() self.session_pool_lock = threading.Lock()
self.addr_to_rooms_tracker = defaultdict(set) self.addr_to_rooms_tracker = defaultdict(set)
self.connection_lock = threading.Lock() self.connection_lock = threading.Lock()
self.required_prefill_info_num_map: Dict[int, int] = {} self.required_prefill_response_num_table: Dict[int, int] = {}
self.decode_kv_arrive_state: Dict[int, Set[int]] = defaultdict(set) self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
# Heartbeat interval should be at least 2 seconds # Heartbeat interval should be at least 2 seconds
self.heartbeat_interval = max( self.heartbeat_interval = max(
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0 float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
...@@ -311,22 +311,23 @@ class MooncakeKVManager(BaseKVManager): ...@@ -311,22 +311,23 @@ class MooncakeKVManager(BaseKVManager):
each page to ensure correctness for any page_size and head-slicing configuration. each page to ensure correctness for any page_size and head-slicing configuration.
This may introduce performance overhead (increased TTFT) for long sequences. This may introduce performance overhead (increased TTFT) for long sequences.
""" """
# rank/kv_head config # Extract configuration
local_tp_rank = self.kv_args.engine_rank local_tp_rank = self.kv_args.engine_rank
local_tp_size = self.tp_size // self.dp_size local_tp_size = self.tp_size // self.dp_size
num_kv_heads = self.kv_args.kv_head_num num_kv_heads = self.kv_args.kv_head_num
num_layers = len(self.kv_args.kv_data_ptrs) num_layers = len(self.kv_args.kv_data_ptrs)
page_size = self.kv_args.page_size page_size = self.kv_args.page_size
# Calculate head distribution
heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
heads_per_prefill_rank = num_kv_heads heads_per_prefill_rank = num_kv_heads
decode_global_head_start = dst_tp_rank * heads_per_decode_rank decode_global_head_start = dst_tp_rank * heads_per_decode_rank
prefill_global_head_start = local_tp_rank * heads_per_prefill_rank prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
# decode config
decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)] decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
# Determine slicing parameters based on TP configuration
if local_tp_size > dst_tp_size: if local_tp_size > dst_tp_size:
src_head_offset = 0 src_head_offset = 0
num_heads_to_send = heads_per_prefill_rank num_heads_to_send = heads_per_prefill_rank
...@@ -340,7 +341,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -340,7 +341,7 @@ class MooncakeKVManager(BaseKVManager):
for layer_id in range(num_layers): for layer_id in range(num_layers):
item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id] item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
# Page stride on the target Decode rank for its slice pages # Page stride on the target dst decode rank for its slice pages
item_len_of_decode_rank_page = decode_rank_item_lens[layer_id] item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0: if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
...@@ -349,12 +350,12 @@ class MooncakeKVManager(BaseKVManager): ...@@ -349,12 +350,12 @@ class MooncakeKVManager(BaseKVManager):
) )
return -1 return -1
# Calculate precise byte offset and length for the sub-slice within Prefill page data # Calculate precise byte offset and length for the sub-slice within the prefill page data
src_slice_offset = src_head_offset * bytes_per_head src_slice_offset = src_head_offset * bytes_per_head
dst_slice_offset = dst_head_offset * bytes_per_head dst_slice_offset = dst_head_offset * bytes_per_head
slice_lens_per_page = num_heads_to_send * bytes_per_head slice_lens_per_page = num_heads_to_send * bytes_per_head
# Sanity check: The data sub-slice we intend to send should fit into D_n's page. # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
# This means slice_lens_per_page <= item_len_of_decode_rank_page # This means slice_lens_per_page <= item_len_of_decode_rank_page
if slice_lens_per_page > item_len_of_decode_rank_page: if slice_lens_per_page > item_len_of_decode_rank_page:
logger.error( logger.error(
...@@ -365,15 +366,13 @@ class MooncakeKVManager(BaseKVManager): ...@@ -365,15 +366,13 @@ class MooncakeKVManager(BaseKVManager):
return -1 return -1
layer_transfer_params.append( layer_transfer_params.append(
( (
self.kv_args.kv_data_ptrs[layer_id], # Prefill base ptr (all heads) self.kv_args.kv_data_ptrs[layer_id],
dst_kv_ptrs[ dst_kv_ptrs[layer_id],
layer_id item_len_of_prefill_rank_page,
], # Decode base ptr (for its slice for this layer) item_len_of_decode_rank_page,
item_len_of_prefill_rank_page, # Prefill page size (all heads)2048 src_slice_offset,
item_len_of_decode_rank_page, # Decode page stride (for its slice page) 1024 dst_slice_offset,
src_slice_offset, # Offset to slice data in Prefill page slice_lens_per_page,
dst_slice_offset, # Offset to slice data in Decode page
slice_lens_per_page, # Length of slice data per page (actual data to send)
) )
) )
...@@ -399,13 +398,13 @@ class MooncakeKVManager(BaseKVManager): ...@@ -399,13 +398,13 @@ class MooncakeKVManager(BaseKVManager):
prefill_page_idx = int(prefill_kv_indices[i]) prefill_page_idx = int(prefill_kv_indices[i])
decode_page_idx = int(dst_kv_indices[i]) decode_page_idx = int(dst_kv_indices[i])
# Get the starting memory address for the current source and destination pages # Get the starting addresses for the current src and dst pages
src_page_start_addr = src_ptr + prefill_page_idx * src_item_len src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
# Iterate through each valid token slot within the current page # Iterate through each valid token slot within the current page
for token_slot_in_page in range(page_size): for token_slot_in_page in range(page_size):
# Calculate start address of the current token slot # Calculate the start address of the current token slot
src_token_slot_start_addr = ( src_token_slot_start_addr = (
src_page_start_addr src_page_start_addr
+ token_slot_in_page * bytes_per_token_on_prefill + token_slot_in_page * bytes_per_token_on_prefill
...@@ -415,7 +414,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -415,7 +414,7 @@ class MooncakeKVManager(BaseKVManager):
+ token_slot_in_page * bytes_per_token_on_decode + token_slot_in_page * bytes_per_token_on_decode
) )
# Calculate final source and destination addresses by applying head-slice offsets # Calculate final src and dst addresses by applying head-slice offsets
src_slice_addr = src_token_slot_start_addr + src_offset src_slice_addr = src_token_slot_start_addr + src_offset
dst_slice_addr = dst_token_slot_start_addr + dst_offset dst_slice_addr = dst_token_slot_start_addr + dst_offset
...@@ -585,9 +584,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -585,9 +584,7 @@ class MooncakeKVManager(BaseKVManager):
ret = self.send_aux( ret = self.send_aux(
req.mooncake_session_id, req.mooncake_session_id,
kv_chunk.prefill_aux_index, kv_chunk.prefill_aux_index,
self.decode_kv_args_table[ target_rank_registration_info.dst_aux_ptrs,
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index, req.dst_aux_index,
) )
polls.append(True if ret == 0 else False) polls.append(True if ret == 0 else False)
...@@ -675,17 +672,17 @@ class MooncakeKVManager(BaseKVManager): ...@@ -675,17 +672,17 @@ class MooncakeKVManager(BaseKVManager):
prefill_rank = int(prefill_rank.decode("ascii")) prefill_rank = int(prefill_rank.decode("ascii"))
if status == KVPoll.Success: if status == KVPoll.Success:
# record arrived prefill_rank if bootstrap_room in self.request_status:
self.decode_kv_arrive_state[bootstrap_room].add(prefill_rank) self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
expected_prefill_num = self.required_prefill_info_num_map[ expected_response_num = (
bootstrap_room self.required_prefill_response_num_table[bootstrap_room]
] )
arrived_prefill_num = len( arrived_response_num = len(
self.decode_kv_arrive_state[bootstrap_room] self.prefill_response_tracker[bootstrap_room]
) )
if ( if (
self.is_mla_backend self.is_mla_backend
or arrived_prefill_num == expected_prefill_num or arrived_response_num == expected_response_num
): ):
self.update_status(bootstrap_room, KVPoll.Success) self.update_status(bootstrap_room, KVPoll.Success)
elif status == KVPoll.Failed: elif status == KVPoll.Failed:
...@@ -900,14 +897,13 @@ class MooncakeKVSender(BaseKVSender): ...@@ -900,14 +897,13 @@ class MooncakeKVSender(BaseKVSender):
self.aux_index = None self.aux_index = None
self.bootstrap_server_url = bootstrap_addr self.bootstrap_server_url = bootstrap_addr
self.conclude_state = None self.conclude_state = None
self.init_time = None self.init_time = time.time()
# inner state # inner state
self.curr_idx = 0 self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None): def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices self.num_kv_indices = num_kv_indices
self.aux_index = aux_index self.aux_index = aux_index
self.init_time = time.time()
def send( def send(
self, self,
...@@ -1031,7 +1027,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1031,7 +1027,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
) )
self.required_dst_info_num = 1 self.required_dst_info_num = 1
self.required_prefill_info_num = 1 self.required_prefill_response_num = 1
self.target_tp_ranks = [self.target_tp_rank] self.target_tp_ranks = [self.target_tp_rank]
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank: elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
if not self.kv_mgr.is_mla_backend: if not self.kv_mgr.is_mla_backend:
...@@ -1044,7 +1040,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1044,7 +1040,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.required_dst_info_num = ( self.required_dst_info_num = (
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
) )
self.required_prefill_info_num = 1 self.required_prefill_response_num = 1
self.target_tp_ranks = [self.target_tp_rank] self.target_tp_ranks = [self.target_tp_rank]
else: else:
if not self.kv_mgr.is_mla_backend: if not self.kv_mgr.is_mla_backend:
...@@ -1067,7 +1063,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1067,7 +1063,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
# or the KVPoll will never be set correctly # or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0] self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1 self.required_dst_info_num = 1
self.required_prefill_info_num = ( self.required_prefill_response_num = (
prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
) )
...@@ -1077,8 +1073,8 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1077,8 +1073,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
else: else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size self.target_dp_group = bootstrap_room % self.prefill_dp_size
self.kv_mgr.required_prefill_info_num_map[self.bootstrap_room] = ( self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
self.required_prefill_info_num self.required_prefill_response_num
) )
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = ( bootstrap_key = (
...@@ -1094,13 +1090,13 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1094,13 +1090,13 @@ class MooncakeKVReceiver(BaseKVReceiver):
) )
if bootstrap_info is not None: if bootstrap_info is not None:
if self.kv_mgr.is_mla_backend: if self.kv_mgr.is_mla_backend:
# MLA :select one prefill rank as real rank # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
bootstrap_info["is_dummy"] = not bool( bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None or self.target_tp_rank is None
) )
else: else:
# no-MLA:select all prefill ranks # For non-MLA: all target_tp_ranks are selected real ranks
bootstrap_info["is_dummy"] = False bootstrap_info["is_dummy"] = False
logger.debug( logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}" f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
...@@ -1240,8 +1236,12 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1240,8 +1236,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
def clear(self) -> None: def clear(self) -> None:
if self.bootstrap_room in self.kv_mgr.request_status: if self.bootstrap_room in self.kv_mgr.request_status:
self.kv_mgr.request_status.pop(self.bootstrap_room) self.kv_mgr.request_status.pop(self.bootstrap_room)
self.kv_mgr.required_prefill_info_num_map.pop(self.bootstrap_room)
self.kv_mgr.decode_kv_arrive_state.pop(self.bootstrap_room) if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
def failure_exception(self): def failure_exception(self):
# Explicitly set the status to failure since this request has failed in another rank # Explicitly set the status to failure since this request has failed in another rank
......
...@@ -97,13 +97,19 @@ class MooncakeTransferEngine: ...@@ -97,13 +97,19 @@ class MooncakeTransferEngine:
peer_buffer_addresses: List[int], peer_buffer_addresses: List[int],
lengths: List[int], lengths: List[int],
) -> int: ) -> int:
"""Synchronously transfer data to the specified address.""" """Synchronously transfer data to the specified addresses in batches."""
try: try:
ret = self.engine.batch_transfer_sync_write( ret = self.engine.batch_transfer_sync_write(
session_id, buffers, peer_buffer_addresses, lengths session_id, buffers, peer_buffer_addresses, lengths
) )
except Exception: except Exception:
ret = -1 ret = -1
# Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
if not hasattr(self.engine, "batch_transfer_sync_write"):
raise RuntimeError(
"Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
"Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
)
if ret < 0: if ret < 0:
logger.debug( logger.debug(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment