"src/vscode:/vscode.git/clone" did not exist on "b8b0fd22b25c5ec3d57a7521b8bdc725504c07b3"
Unverified Commit 8f2cd177 authored by shaharmor98's avatar shaharmor98 Committed by GitHub
Browse files

add code pp support for nixl (#11375)


Signed-off-by: default avatarShahar Mor <smor@nvidia.com>
parent ab926dd6
...@@ -319,14 +319,44 @@ class NixlKVManager(CommonKVManager): ...@@ -319,14 +319,44 @@ class NixlKVManager(CommonKVManager):
logger.debug(f"sending kvcache to {peer_name} with notif {notif}") logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
# Make descs # Make descs
num_layers = len(self.kv_args.kv_data_ptrs) if self.is_mla_backend:
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
)
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
src_kv_ptrs[layer_id],
dst_kv_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_current_pp_stage)
]
else:
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
)
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
src_k_ptrs[layer_id],
dst_k_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_current_pp_stage)
] + [
(
src_v_ptrs[layer_id],
dst_v_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_current_pp_stage)
]
src_addrs = [] src_addrs = []
dst_addrs = [] dst_addrs = []
for layer_id in range(num_layers): for src_ptr, dst_ptr, item_len in layers_params:
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
dst_ptr = dst_kv_ptrs[layer_id]
item_len = self.kv_args.kv_item_lens[layer_id]
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
src_addr = src_ptr + int(prefill_index[0]) * item_len src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len
...@@ -397,6 +427,9 @@ class NixlKVManager(CommonKVManager): ...@@ -397,6 +427,9 @@ class NixlKVManager(CommonKVManager):
num_heads_to_send = dst_heads_per_rank num_heads_to_send = dst_heads_per_rank
dst_head_start_offset = 0 dst_head_start_offset = 0
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
)
# Create transfer descriptors # Create transfer descriptors
src_addrs = [] src_addrs = []
dst_addrs = [] dst_addrs = []
...@@ -404,12 +437,6 @@ class NixlKVManager(CommonKVManager): ...@@ -404,12 +437,6 @@ class NixlKVManager(CommonKVManager):
bytes_per_token_on_prefill = src_kv_item_len // page_size bytes_per_token_on_prefill = src_kv_item_len // page_size
bytes_per_token_on_decode = dst_kv_item_len // page_size bytes_per_token_on_decode = dst_kv_item_len // page_size
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
# Calculate precise byte offset and length for the sub-slice within the token # Calculate precise byte offset and length for the sub-slice within the token
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
...@@ -420,13 +447,13 @@ class NixlKVManager(CommonKVManager): ...@@ -420,13 +447,13 @@ class NixlKVManager(CommonKVManager):
src_k_ptrs[layer_id], src_k_ptrs[layer_id],
dst_k_ptrs[layer_id], dst_k_ptrs[layer_id],
) )
for layer_id in range(len(src_k_ptrs)) for layer_id in range(layers_current_pp_stage)
] + [ ] + [
( (
src_v_ptrs[layer_id], src_v_ptrs[layer_id],
dst_v_ptrs[layer_id], dst_v_ptrs[layer_id],
) )
for layer_id in range(len(src_v_ptrs)) for layer_id in range(layers_current_pp_stage)
] ]
src_addrs = [] src_addrs = []
...@@ -496,14 +523,19 @@ class NixlKVManager(CommonKVManager): ...@@ -496,14 +523,19 @@ class NixlKVManager(CommonKVManager):
dst_aux_index: int, dst_aux_index: int,
notif: str, notif: str,
): ):
# Make descs src_addrs = []
aux_item_len = self.kv_args.aux_item_lens[0] dst_addrs = []
prefill_aux_addr = (
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len prefill_aux_ptrs = self.kv_args.aux_data_ptrs
) prefill_aux_item_lens = self.kv_args.aux_item_lens
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
src_addrs = [(prefill_aux_addr, aux_item_len, 0)] for i, _ in enumerate(dst_aux_ptrs):
dst_addrs = [(decode_aux_addr, aux_item_len, 0)] length = prefill_aux_item_lens[i]
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
src_addrs.append((src_addr, length, 0))
dst_addrs.append((dst_addr, length, 0))
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM") src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM") dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
# Transfer data # Transfer data
...@@ -576,7 +608,7 @@ class NixlKVManager(CommonKVManager): ...@@ -576,7 +608,7 @@ class NixlKVManager(CommonKVManager):
handles.append(kv_xfer_handle) handles.append(kv_xfer_handle)
# Only the last chunk we need to send the aux data. # Only the last chunk we need to send the aux data.
if is_last: if is_last and self.pp_group.is_last_rank:
assert aux_index is not None assert aux_index is not None
aux_xfer_handle = self.send_aux( aux_xfer_handle = self.send_aux(
req.agent_name, req.agent_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment