"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5c4ea00de772f9af456e68f30f830c7d7a158846"
Unverified Commit 9c339d6b authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

[PD] Extract the PP transfer layer calculate logic from Mooncake to Common backend (#10565)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
parent e23e280e
...@@ -95,14 +95,6 @@ class CommonKVManager(BaseKVManager): ...@@ -95,14 +95,6 @@ class CommonKVManager(BaseKVManager):
def _bind_server_socket(self): def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
def _register_to_bootstrap(self): def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST.""" """Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr: if self.dist_init_addr:
...@@ -156,6 +148,33 @@ class CommonKVManager(BaseKVManager): ...@@ -156,6 +148,33 @@ class CommonKVManager(BaseKVManager):
socket.connect(endpoint) socket.connect(endpoint)
return socket return socket
def get_mha_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], List[int], List[int], int]:
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
num_kv_layers = len(src_kv_ptrs) // 2
end_layer = start_layer + num_kv_layers
dst_num_total_layers = len(dst_kv_ptrs) // 2
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
layers_current_pp_stage = len(src_k_ptrs)
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
def get_mla_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], int]:
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(src_kv_ptrs)
sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
layers_current_pp_stage = len(src_kv_ptrs)
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
class CommonKVSender(BaseKVSender): class CommonKVSender(BaseKVSender):
......
...@@ -264,12 +264,10 @@ class MooncakeKVManager(CommonKVManager): ...@@ -264,12 +264,10 @@ class MooncakeKVManager(CommonKVManager):
layers_params = None layers_params = None
# pp is not supported on the decode side yet # pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
if self.is_mla_backend: if self.is_mla_backend:
src_kv_ptrs = self.kv_args.kv_data_ptrs src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
layers_per_pp_stage = len(src_kv_ptrs) self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer] )
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [ layers_params = [
( (
...@@ -277,18 +275,12 @@ class MooncakeKVManager(CommonKVManager): ...@@ -277,18 +275,12 @@ class MooncakeKVManager(CommonKVManager):
dst_kv_ptrs[layer_id], dst_kv_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] ]
else: else:
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
dst_num_total_layers = num_kv_layers * self.pp_size self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers] )
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [ layers_params = [
( (
...@@ -296,14 +288,14 @@ class MooncakeKVManager(CommonKVManager): ...@@ -296,14 +288,14 @@ class MooncakeKVManager(CommonKVManager):
dst_k_ptrs[layer_id], dst_k_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] + [ ] + [
( (
src_v_ptrs[layer_id], src_v_ptrs[layer_id],
dst_v_ptrs[layer_id], dst_v_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] ]
assert layers_params is not None assert layers_params is not None
...@@ -401,18 +393,9 @@ class MooncakeKVManager(CommonKVManager): ...@@ -401,18 +393,9 @@ class MooncakeKVManager(CommonKVManager):
num_heads_to_send = dst_heads_per_rank num_heads_to_send = dst_heads_per_rank
dst_head_start_offset = 0 dst_head_start_offset = 0
# pp is not supported on the decode side yet src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2 self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
dst_num_total_layers = num_kv_layers * self.pp_size )
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
# Calculate precise byte offset and length for the sub-slice within the token # Calculate precise byte offset and length for the sub-slice within the token
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
...@@ -438,7 +421,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -438,7 +421,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset, dst_head_slice_offset,
heads_bytes_per_token_to_send, heads_bytes_per_token_to_send,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] + [ ] + [
( (
src_v_ptrs[layer_id], src_v_ptrs[layer_id],
...@@ -449,7 +432,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -449,7 +432,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset, dst_head_slice_offset,
heads_bytes_per_token_to_send, heads_bytes_per_token_to_send,
) )
for layer_id in range(layers_per_pp_stage) for layer_id in range(layers_current_pp_stage)
] ]
def process_layer_tp_aware(layer_params): def process_layer_tp_aware(layer_params):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment