"docs/vscode:/vscode.git/clone" did not exist on "3f6f6941598f669bf05447cc50018ef63cc7ab02"
Unverified Commit bb9b608c authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

[PD][NIXL] Set is_sorted=False to fix NIXL_ERR_NOT_FOUND (#7330)

parent 69183f88
......@@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager):
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
):
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
if not self.kv_descs:
raise Exception("NIXL memory registration failed for kv tensors")
......@@ -168,7 +168,7 @@ class NixlKVManager(CommonKVManager):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
):
aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
if not self.aux_descs:
raise Exception("NIXL memory registration failed for aux tensors")
......@@ -215,8 +215,8 @@ class NixlKVManager(CommonKVManager):
logger.debug(
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
)
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
# Transfer data
xfer_handle = self.agent.initialize_xfer(
"WRITE",
......@@ -248,8 +248,8 @@ class NixlKVManager(CommonKVManager):
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
# Transfer data
xfer_handle = self.agent.initialize_xfer(
"WRITE",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment