"src/vscode:/vscode.git/clone" did not exist on "c8d86e9f0a2791eb0d08b2692803cab5ea7a35e2"
Commit bb9b608c (unverified), authored by Trevor Morris, committed by GitHub
Browse files

[PD][NIXL] Set is_sorted=False to fix NIXL_ERR_NOT_FOUND (#7330)

parent 69183f88
...@@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager): ...@@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager):
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
): ):
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, "")) kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True) self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}") logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
if not self.kv_descs: if not self.kv_descs:
raise Exception("NIXL memory registration failed for kv tensors") raise Exception("NIXL memory registration failed for kv tensors")
...@@ -168,7 +168,7 @@ class NixlKVManager(CommonKVManager): ...@@ -168,7 +168,7 @@ class NixlKVManager(CommonKVManager):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
): ):
aux_addrs.append((aux_data_ptr, aux_data_len, 0, "")) aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True) self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}") logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
if not self.aux_descs: if not self.aux_descs:
raise Exception("NIXL memory registration failed for aux tensors") raise Exception("NIXL memory registration failed for aux tensors")
...@@ -215,8 +215,8 @@ class NixlKVManager(CommonKVManager): ...@@ -215,8 +215,8 @@ class NixlKVManager(CommonKVManager):
logger.debug( logger.debug(
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}" f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
) )
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True) src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True) dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
# Transfer data # Transfer data
xfer_handle = self.agent.initialize_xfer( xfer_handle = self.agent.initialize_xfer(
"WRITE", "WRITE",
...@@ -248,8 +248,8 @@ class NixlKVManager(CommonKVManager): ...@@ -248,8 +248,8 @@ class NixlKVManager(CommonKVManager):
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
src_addrs = [(prefill_aux_addr, aux_item_len, 0)] src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
dst_addrs = [(decode_aux_addr, aux_item_len, 0)] dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True) src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True) dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
# Transfer data # Transfer data
xfer_handle = self.agent.initialize_xfer( xfer_handle = self.agent.initialize_xfer(
"WRITE", "WRITE",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment