Unverified commit d98a4913, authored by Shangming Cai, committed by GitHub

[PD] Refactor parallel sizes and add pp support for mooncake (#8571)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 08f8f490
@@ -25,10 +25,13 @@ class KVArgs:
     gpu_id: int
     # for different tp
     decode_tp_size: int
-    # for pp prefill
-    prefill_pp_size: int
     kv_head_num: int
     page_size: int
+    # for pp prefill
+    prefill_pp_size: int
+    pp_rank: int
+    # for system dp
+    system_dp_rank: int


 class KVPoll:
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -184,9 +185,13 @@ class DecodePreallocQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
-        attn_tp_size = self.tp_size // self.dp_size
+        attn_tp_size = get_attention_tp_size()
         kv_args.engine_rank = self.tp_rank % (attn_tp_size)
         kv_args.decode_tp_size = attn_tp_size
+        # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0
+        kv_args.pp_rank = 0
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
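To make the rewiring above concrete, here is a worked sketch (sizes invented) of the decode-side rank bookkeeping. It assumes get_attention_tp_size() returns tp_size // dp_size in this configuration, which is exactly the expression the hunk replaces; pp_rank stays 0 because PP is not supported on the decode side yet.

tp_size, attn_dp_size = 8, 2
attn_tp_size = tp_size // attn_dp_size  # assumed get_attention_tp_size() result: 4

for tp_rank in range(tp_size):
    engine_rank = tp_rank % attn_tp_size  # rank within the local attention-TP group
    print(f"tp_rank={tp_rank} -> engine_rank={engine_rank}, pp_rank=0")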
@@ -34,6 +34,12 @@ from sglang.srt.disaggregation.common.utils import (
 )
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
@@ -113,7 +119,7 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     dst_tp_rank: int
-    dst_tp_size: int
+    dst_attn_tp_size: int
     dst_kv_item_len: int

     @classmethod
@@ -126,7 +132,7 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_tp_rank=int(msg[6].decode("ascii")),
-            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_attn_tp_size=int(msg[7].decode("ascii")),
             dst_kv_item_len=int(msg[8].decode("ascii")),
         )
@@ -147,13 +153,18 @@ class MooncakeKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
-        self.tp_size = server_args.tp_size
-        self.dp_size = server_args.dp_size
-        self.enable_dp_attention = server_args.enable_dp_attention
-        if not server_args.enable_dp_attention and server_args.dp_size != 1:
-            raise ValueError(
-                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
-            )
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_dp_size = get_attention_dp_size()
+        self.attn_dp_rank = get_attention_dp_rank()
+        self.system_dp_size = (
+            1 if server_args.enable_dp_attention else server_args.dp_size
+        )
+        self.system_dp_rank = (
+            self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
+        )
+        self.pp_size = server_args.pp_size
+        self.pp_rank = self.kv_args.pp_rank
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
@@ -221,8 +232,9 @@ class MooncakeKVManager(BaseKVManager):
             )
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
-            self.prefill_tp_size_table: Dict[str, int] = {}
+            self.prefill_attn_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
+            self.prefill_pp_size_table: Dict[str, int] = {}
             # If a timeout happens on the decode side, it means decode instances
             # fail to receive the KV Cache transfer done signal after bootstrapping.
             # These timeout requests should be aborted to release the tree cache.
@@ -296,15 +308,53 @@ class MooncakeKVManager(BaseKVManager):
             prefill_kv_indices, dst_kv_indices
         )
+        layers_params = None

-        num_layers = len(self.kv_args.kv_data_ptrs)
-        layers_params = [
-            (
-                self.kv_args.kv_data_ptrs[layer_id],
-                dst_kv_ptrs[layer_id],
-                self.kv_args.kv_item_lens[layer_id],
-            )
-            for layer_id in range(num_layers)
-        ]
+        # pp is not supported on the decode side yet
+        if self.is_mla_backend:
+            src_kv_ptrs = self.kv_args.kv_data_ptrs
+            layers_per_pp_stage = len(src_kv_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            kv_item_len = self.kv_args.kv_item_lens[0]
+            layers_params = [
+                (
+                    src_kv_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        else:
+            num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+            src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+            src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+            layers_per_pp_stage = len(src_k_ptrs)
+            start_layer = self.pp_rank * layers_per_pp_stage
+            end_layer = start_layer + layers_per_pp_stage
+            dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+            dst_v_ptrs = dst_kv_ptrs[
+                num_kv_layers + start_layer : num_kv_layers + end_layer
+            ]
+            kv_item_len = self.kv_args.kv_item_lens[0]
+            layers_params = [
+                (
+                    src_k_ptrs[layer_id],
+                    dst_k_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ] + [
+                (
+                    src_v_ptrs[layer_id],
+                    dst_v_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_per_pp_stage)
+            ]
+        assert layers_params is not None

         # Worker function for processing a single layer
         def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
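The pointer slicing above is the heart of the PP support: a prefill stage holds only layers_per_pp_stage layers locally, while the decode side (which runs without PP) registered pointers for every layer, so the stage starts reading the destination list at pp_rank * layers_per_pp_stage. A self-contained sketch of the MLA branch, with the helper name and pointer values invented:

def mla_layer_pairs(src_kv_ptrs, dst_kv_ptrs, pp_rank, kv_item_len):
    # len(src_kv_ptrs) is the number of layers this PP stage holds locally;
    # the decode side registered all layers, so offset by pp_rank stages.
    layers_per_pp_stage = len(src_kv_ptrs)
    start = pp_rank * layers_per_pp_stage
    dst_slice = dst_kv_ptrs[start : start + layers_per_pp_stage]
    return [(s, d, kv_item_len) for s, d in zip(src_kv_ptrs, dst_slice)]

# Example: 8 decode layers, prefill pp_size=2, this stage is pp_rank=1.
decode_ptrs = [0x1000 + 0x10 * i for i in range(8)]
prefill_stage_ptrs = [0x2000 + 0x10 * i for i in range(4)]
pairs = mla_layer_pairs(prefill_stage_ptrs, decode_ptrs, pp_rank=1, kv_item_len=256)
assert [d for _, d, _ in pairs] == decode_ptrs[4:8]  # stage 1 fills layers 4..7

For the non-MLA branch, the same offset is applied twice, because the K pointers occupy the first half of the registered list and the V pointers the second half.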
@@ -343,7 +393,7 @@ class MooncakeKVManager(BaseKVManager):
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
         dst_tp_rank: int,
-        dst_tp_size: int,
+        dst_attn_tp_size: int,
         dst_kv_item_len: int,
         executor: concurrent.futures.ThreadPoolExecutor,
     ):
@@ -356,23 +406,22 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-        local_tp_size = self.tp_size // self.dp_size
-        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
         src_kv_item_len = self.kv_args.kv_item_lens[0]
-        dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
+        dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size

         # Calculate head distribution
         src_heads_per_rank = num_kv_heads
-        dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
+        dst_heads_per_rank = num_kv_heads * self.attn_tp_size // dst_attn_tp_size
         bytes_per_head_slice_to_send = (
             dst_kv_item_len // page_size // dst_heads_per_rank
         )

         # Determine slicing parameters based on TP configuration
-        if local_tp_size > dst_tp_size:
+        if self.attn_tp_size > dst_attn_tp_size:
             # Send KVCache from multiple prefill instances to 1 decode instance
             src_head_start_offset = 0
             num_heads_to_send = src_heads_per_rank
@@ -383,35 +432,55 @@ class MooncakeKVManager(BaseKVManager):
             num_heads_to_send = dst_heads_per_rank
             dst_head_start_offset = 0

-        layers_params = []
-        for layer_id in range(num_layers):
-            # Calculate precise byte offset and length for the sub-slice within the token
-            src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
-            dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
-            heads_bytes_per_token_to_send = (
-                num_heads_to_send * bytes_per_head_slice_to_send
-            )
-
-            # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
-            # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
-            if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
-                logger.error(
-                    f"[{mooncake_session_id}] Layer {layer_id}: "
-                    f"slice size ({heads_bytes_per_token_to_send}) exceeds "
-                    f"target token slot size ({dst_kv_item_len // page_size})"
-                )
-                return -1
-            layers_params.append(
-                (
-                    self.kv_args.kv_data_ptrs[layer_id],
-                    dst_kv_ptrs[layer_id],
-                    src_kv_item_len,
-                    dst_kv_item_len,
-                    src_head_slice_offset,
-                    dst_head_slice_offset,
-                    heads_bytes_per_token_to_send,
-                )
-            )
+        # pp is not supported on the decode side yet
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        layers_per_pp_stage = len(src_k_ptrs)
+        start_layer = self.pp_rank * layers_per_pp_stage
+        end_layer = start_layer + layers_per_pp_stage
+        dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+        dst_v_ptrs = dst_kv_ptrs[
+            num_kv_layers + start_layer : num_kv_layers + end_layer
+        ]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+        # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+        if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
+            logger.error(
+                f"[{mooncake_session_id}] slice size ({heads_bytes_per_token_to_send}) exceeds "
+                f"target token slot size ({dst_kv_item_len // page_size})"
+            )
+            return -1
+
+        layers_params = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
+            )
+            for layer_id in range(layers_per_pp_stage)
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+                src_kv_item_len,
+                dst_kv_item_len,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
+            )
+            for layer_id in range(layers_per_pp_stage)
+        ]

         def process_layer_tp_aware(layer_params):
             (
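The head arithmetic above is easiest to see with numbers. A worked sketch, all sizes invented (fp16, head_dim 64), for prefill attention TP 2 sending to decode attention TP 4, where each prefill rank feeds two decode ranks with half of its KV heads:

attn_tp_size, dst_attn_tp_size = 2, 4  # prefill TP 2 -> decode TP 4
num_kv_heads = 8                       # KV heads held by one prefill rank
page_size = 16
bytes_per_head = 64 * 2                # head_dim * sizeof(fp16)

src_heads_per_rank = num_kv_heads                                     # 8
dst_heads_per_rank = num_kv_heads * attn_tp_size // dst_attn_tp_size  # 4
dst_kv_item_len = dst_heads_per_rank * bytes_per_head * page_size
bytes_per_head_slice_to_send = dst_kv_item_len // page_size // dst_heads_per_rank
assert bytes_per_head_slice_to_send == bytes_per_head

# attn_tp_size < dst_attn_tp_size: send a dst_heads_per_rank-wide slice of
# the source heads to each decode rank, written at dst_head_start_offset 0.
num_heads_to_send = dst_heads_per_rank
heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
# The sanity check from the hunk above: the slice fits one dst token slot.
assert heads_bytes_per_token_to_send <= dst_kv_item_len // page_size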
@@ -562,9 +631,9 @@ class MooncakeKVManager(BaseKVManager):
                 target_rank_registration_info: KVArgsRegisterInfo = (
                     self.decode_kv_args_table[req.mooncake_session_id]
                 )
-                local_tp_size = self.tp_size // self.dp_size
                 if self.is_mla_backend or (
-                    local_tp_size == target_rank_registration_info.dst_tp_size
+                    self.attn_tp_size
+                    == target_rank_registration_info.dst_attn_tp_size
                 ):
                     ret = self.send_kvcache(
                         req.mooncake_session_id,
@@ -580,7 +649,7 @@ class MooncakeKVManager(BaseKVManager):
                         target_rank_registration_info.dst_kv_ptrs,
                         chunked_dst_kv_indice,
                         target_rank_registration_info.dst_tp_rank,
-                        target_rank_registration_info.dst_tp_size,
+                        target_rank_registration_info.dst_attn_tp_size,
                         target_rank_registration_info.dst_kv_item_len,
                         executor,
                     )
@@ -863,11 +932,16 @@ class MooncakeKVManager(BaseKVManager):
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
-            "tp_size": self.tp_size,
-            "dp_size": self.dp_size,
+            "attn_tp_size": self.attn_tp_size,
+            "attn_tp_rank": self.attn_tp_rank,
+            "attn_dp_size": self.attn_dp_size,
+            "attn_dp_rank": self.attn_dp_rank,
+            "pp_size": self.pp_size,
+            "pp_rank": self.pp_rank,
+            "system_dp_size": self.system_dp_size,
+            "system_dp_rank": self.system_dp_rank,
             "rank_ip": self.local_ip,
             "rank_port": self.rank_port,
-            "engine_rank": self.kv_args.engine_rank,
         }

         try:
@@ -890,10 +964,12 @@ class MooncakeKVManager(BaseKVManager):
             ]
             for k in keys_to_remove:
                 del self.connection_pool[k]
-            if failed_bootstrap_addr in self.prefill_tp_size_table:
-                del self.prefill_tp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+                del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
             if failed_bootstrap_addr in self.prefill_dp_size_table:
                 del self.prefill_dp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_pp_size_table:
+                del self.prefill_pp_size_table[failed_bootstrap_addr]

             possible_affected_rooms = self.addr_to_rooms_tracker.get(
                 failed_bootstrap_addr, []
@@ -915,7 +991,7 @@ class MooncakeKVManager(BaseKVManager):
                     self.update_status(room, KVPoll.Failed)
                     affected_rooms.append(room)
             logger.error(
-                f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), affected {len(affected_rooms)} requests"
+                f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
             )
@@ -1042,10 +1118,16 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.data_parallel_rank = data_parallel_rank

         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
-            self.prefill_tp_size, self.prefill_dp_size = (
-                self._get_prefill_parallel_info_from_server()
-            )
-            if self.prefill_tp_size is None or self.prefill_dp_size is None:
+            (
+                self.prefill_attn_tp_size,
+                self.prefill_dp_size,
+                self.prefill_pp_size,
+            ) = self._get_prefill_parallel_info_from_server()
+            if (
+                self.prefill_attn_tp_size is None
+                or self.prefill_dp_size is None
+                or self.prefill_pp_size is None
+            ):
                 self.kv_mgr.record_failure(
                     self.bootstrap_room,
                     f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
@@ -1054,43 +1136,47 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 return
             else:
                 logger.debug(
-                    f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_tp_size}"
+                    f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
                 )
-                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
-                    self.prefill_tp_size
+                self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
+                    self.prefill_attn_tp_size
                 )
                 self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                     self.prefill_dp_size
                 )
+                self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
+                    self.prefill_pp_size
+                )
         else:
-            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
+            self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
                 self.bootstrap_addr
             ]
             self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                 self.bootstrap_addr
             ]
+            self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
+                self.bootstrap_addr
+            ]

         # Currently, we don't allow prefill instance and decode instance to
         # have different TP sizes per DP rank, except for models using MLA.
-        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
-        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
-        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
+        if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
             )
             self.required_dst_info_num = 1
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
-        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
+        elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
             if not self.kv_mgr.is_mla_backend:
                 logger.warning_once(
                     "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
                 )
             self.target_tp_rank = (
-                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
-            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
+                self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
+            ) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
             self.required_dst_info_num = (
-                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
+                self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
             )
             self.required_prefill_response_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
@@ -1103,10 +1189,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_ranks = [
                 rank
                 for rank in range(
-                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
-                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
-                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
-                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
+                    (self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
+                    * (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
                 )
             ]
@@ -1116,7 +1202,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
             self.required_prefill_response_num = (
-                prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
+                self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
             )

         if self.data_parallel_rank is not None:
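For the branch at -1103 above, where the prefill attention-TP group is wider than the decode one, a decode rank pulls from a contiguous block of prefill ranks and waits for one response from each. A worked example with invented sizes:

# Decode attention TP 2 pulling from prefill attention TP 4.
decode_attn_tp, prefill_attn_tp = 2, 4
ratio = prefill_attn_tp // decode_attn_tp  # 2 prefill ranks per decode rank

engine_rank = 1  # this decode rank
target_tp_ranks = list(
    range(
        (engine_rank % decode_attn_tp) * ratio,
        (engine_rank % decode_attn_tp + 1) * ratio,
    )
)
assert target_tp_ranks == [2, 3]
required_prefill_response_num = ratio  # wait for a transfer from each of them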
@@ -1136,31 +1222,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
         if bootstrap_key not in self.kv_mgr.connection_pool:
             bootstrap_infos = []
             for target_tp_rank in self.target_tp_ranks:
-                bootstrap_info = self._get_bootstrap_info_from_server(
-                    target_tp_rank,
-                    self.target_dp_group,
-                )
-                if bootstrap_info is not None:
-                    if self.kv_mgr.is_mla_backend:
-                        # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
-                        bootstrap_info["is_dummy"] = not bool(
-                            target_tp_rank == self.target_tp_rank
-                            or self.target_tp_rank is None
-                        )
-                    else:
-                        # For non-MLA: all target_tp_ranks are selected real ranks
-                        bootstrap_info["is_dummy"] = False
-                    logger.debug(
-                        f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
-                    )
-                    bootstrap_infos.append(bootstrap_info)
-                else:
-                    self.kv_mgr.record_failure(
-                        self.bootstrap_room,
-                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}",
-                    )
-                    self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
-                    return
+                for target_pp_rank in range(self.prefill_pp_size):
+                    bootstrap_info = self._get_bootstrap_info_from_server(
+                        target_tp_rank, self.target_dp_group, target_pp_rank
+                    )
+                    if bootstrap_info is not None:
+                        if self.kv_mgr.is_mla_backend:
+                            # For MLA: target_tp_rank is the selected real rank, others are dummy ranks
+                            bootstrap_info["is_dummy"] = not bool(
+                                target_tp_rank == self.target_tp_rank
+                                or self.target_tp_rank is None
+                            )
+                        else:
+                            # For non-MLA: all target_tp_ranks are selected real ranks
+                            bootstrap_info["is_dummy"] = False
+                        logger.debug(
+                            f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
+                        )
+                        bootstrap_infos.append(bootstrap_info)
+                    else:
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
+                        )
+                        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                        return

             self.bootstrap_infos = bootstrap_infos
             self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
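With the target_pp_rank loop added above, the number of bootstrap connections a decode rank opens becomes the TP fan-out times the prefill PP depth. A quick sanity sketch with invented sizes:

target_tp_ranks = [2, 3]  # e.g. prefill attention TP 4, decode attention TP 2
prefill_pp_size = 2
bootstrap_infos = [
    (tp, pp) for tp in target_tp_ranks for pp in range(prefill_pp_size)
]
assert len(bootstrap_infos) == len(target_tp_ranks) * prefill_pp_size  # 4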
@@ -1174,10 +1260,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)

-    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
+    def _get_bootstrap_info_from_server(
+        self, engine_rank, target_dp_group, target_pp_rank
+    ):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
             response = requests.get(url, timeout=5)
             if response.status_code == 200:
                 bootstrap_info = response.json()
@@ -1191,24 +1279,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

-    def _get_prefill_parallel_info_from_server(self) -> Tuple[int, int]:
+    def _get_prefill_parallel_info_from_server(
+        self,
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
         """Fetch the prefill parallel info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
+            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
             response = requests.get(url)
             if response.status_code == 200:
                 prefill_parallel_info = response.json()
-                return int(prefill_parallel_info["prefill_tp_size"]), int(
-                    prefill_parallel_info["prefill_dp_size"]
+                return (
+                    int(prefill_parallel_info["prefill_attn_tp_size"]),
+                    int(prefill_parallel_info["prefill_dp_size"]),
+                    int(prefill_parallel_info["prefill_pp_size"]),
                 )
             else:
                 logger.error(
                     f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                 )
-                return None, None
+                return None, None, None
         except Exception as e:
             logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
-            return None, None
+            return None, None, None
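The sentinel handshake above, -1 for every query parameter, lets a decode instance learn the prefill parallel layout before asking for any concrete rank address. A minimal client-side sketch (bootstrap_addr hypothetical), using the same response keys the bootstrap server returns:

import requests

def fetch_prefill_parallel_info(bootstrap_addr: str):
    # engine_rank / target_dp_group / target_pp_rank of -1 ask the bootstrap
    # server for sizes instead of a concrete rank's (ip, port).
    url = (
        f"http://{bootstrap_addr}/route"
        "?engine_rank=-1&target_dp_group=-1&target_pp_rank=-1"
    )
    response = requests.get(url, timeout=5)
    if response.status_code != 200:
        return None, None, None
    info = response.json()
    return (
        int(info["prefill_attn_tp_size"]),
        int(info["prefill_dp_size"]),
        int(info["prefill_pp_size"]),
    )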
     def _register_kv_args(self):
         for bootstrap_info in self.bootstrap_infos:
@@ -1218,11 +1310,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
             packed_aux_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
+            # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
             tp_rank = self.kv_mgr.kv_args.engine_rank
-            tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
             kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
             dst_tp_rank = str(tp_rank).encode("ascii")
-            dst_tp_size = str(tp_size).encode("ascii")
+            dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
             dst_kv_item_len = str(kv_item_len).encode("ascii")

             sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
@@ -1236,7 +1328,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
                     packed_kv_data_ptrs,
                     packed_aux_data_ptrs,
                     dst_tp_rank,
-                    dst_tp_size,
+                    dst_attn_tp_size,
                     dst_kv_item_len,
                 ]
             )
@@ -1347,10 +1439,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.store = dict()
         self.lock = asyncio.Lock()
         self._setup_routes()
-        self.tp_size = None
+        self.pp_size = None
+        self.attn_tp_size = None
         self.dp_size = None
-        self.tp_size_per_dp_rank = None
-        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
+        self.prefill_port_table: Dict[
+            int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
+        ] = {}

         # Start bootstrap server
         self.thread = threading.Thread(target=self._run_server, daemon=True)
@@ -1380,37 +1474,45 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_put(self, request: web.Request):
         data = await request.json()
         role = data["role"]
-        tp_size = data["tp_size"]
-        dp_size = data["dp_size"]
+        attn_tp_size = data["attn_tp_size"]
+        attn_tp_rank = data["attn_tp_rank"]
+        attn_dp_size = data["attn_dp_size"]
+        attn_dp_rank = data["attn_dp_rank"]
+        pp_size = data["pp_size"]
+        pp_rank = data["pp_rank"]
+        system_dp_size = data["system_dp_size"]
+        system_dp_rank = data["system_dp_rank"]
         rank_ip = data["rank_ip"]
         rank_port = int(data["rank_port"])
-        engine_rank = int(data["engine_rank"])

-        if self.tp_size is None:
-            self.tp_size = tp_size
+        if self.attn_tp_size is None:
+            self.attn_tp_size = attn_tp_size

         if self.dp_size is None:
-            self.dp_size = dp_size
+            self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size

-        tp_size_per_dp_rank = tp_size // dp_size
-        if self.tp_size_per_dp_rank is None:
-            self.tp_size_per_dp_rank = tp_size_per_dp_rank
+        if self.pp_size is None:
+            self.pp_size = pp_size

         if role == "Prefill":
-            dp_group = engine_rank // tp_size_per_dp_rank
-            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
+            if system_dp_size == 1:
+                dp_group = attn_dp_rank
+            else:
+                dp_group = system_dp_rank

             # Add lock to make sure thread-safe
             async with self.lock:
                 if dp_group not in self.prefill_port_table:
                     self.prefill_port_table[dp_group] = {}
+                if attn_tp_rank not in self.prefill_port_table[dp_group]:
+                    self.prefill_port_table[dp_group][attn_tp_rank] = {}

-                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
+                self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
                 }
             logger.debug(
-                f"Register prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
+                f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )

         return web.Response(text="OK", status=200)
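The PUT handler above grows the registration table from two levels to three. A standalone sketch of the resulting shape, dp_group -> attn_tp_rank -> pp_rank -> address, with invented addresses; setdefault stands in for the locked membership checks:

from typing import Dict, Union

prefill_port_table: Dict[int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]] = {}

def register(dp_group: int, attn_tp_rank: int, pp_rank: int, ip: str, port: int):
    # Each (DP group, attention-TP rank, PP stage) triple owns one address.
    prefill_port_table.setdefault(dp_group, {}).setdefault(attn_tp_rank, {})[
        pp_rank
    ] = {"rank_ip": ip, "rank_port": port}

# One DP group of a prefill with attention TP 2 and PP 2 registers 4 ranks.
for tp in range(2):
    for pp in range(2):
        register(0, tp, pp, "10.0.0.1", 9000 + tp * 2 + pp)
assert prefill_port_table[0][1][0]["rank_port"] == 9002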
@@ -1418,14 +1520,20 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         target_dp_group = request.query.get("target_dp_group")
-        if not engine_rank or not target_dp_group:
+        target_pp_rank = request.query.get("target_pp_rank")
+        if not engine_rank or not target_dp_group or not target_pp_rank:
             return web.Response(text="Missing inputs for bootstrap server.", status=400)

         # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
-        if int(engine_rank) == -1 and int(target_dp_group) == -1:
+        if (
+            int(engine_rank) == -1
+            and int(target_dp_group) == -1
+            and int(target_pp_rank) == -1
+        ):
             prefill_parallel_info = {
-                "prefill_tp_size": self.tp_size,
+                "prefill_attn_tp_size": self.attn_tp_size,
                 "prefill_dp_size": self.dp_size,
+                "prefill_pp_size": self.pp_size,
             }
             return web.json_response(prefill_parallel_info, status=200)
@@ -1433,7 +1541,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         async with self.lock:
             bootstrap_info = self.prefill_port_table[int(target_dp_group)][
                 int(engine_rank)
-            ]
+            ][int(target_pp_rank)]

         if bootstrap_info is not None:
             return web.json_response(bootstrap_info, status=200)
@@ -103,6 +103,8 @@ class PrefillBootstrapQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
         kv_args.engine_rank = self.tp_rank
+        kv_args.pp_rank = self.pp_rank
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
         kv_args.prefill_pp_size = self.pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (