Commit ed83e246 authored by Anant Sharma, committed by GitHub
Browse files

chore: update nixl github commit (#214)


Co-authored-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com>
parent 70266ec8
......@@ -157,7 +157,6 @@ ENV NIXL_PLUGIN_DIR=/usr/local/nixl/lib/x86_64-linux-gnu/plugins
RUN ls -l /usr/local/nixl/
RUN ls -l /usr/local/nixl/include/
RUN ls -l /usr/local/nixl/include/internal/
RUN ls /opt/nixl
......
......@@ -60,7 +60,7 @@ TENSORRTLLM_PIP_WHEEL_PATH=""
VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
VLLM_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
NIXL_COMMIT=d7a2c571a60d76a3d6c8458140eaaa5025fa48c4
NIXL_COMMIT=f35faf8ba4e725f1724177d0772200481d1d3446
NIXL_REPO=ai-dynamo/nixl.git
get_options() {
......
......@@ -810,7 +810,7 @@ index 00000000..9b938039
\ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..9b757396
index 00000000..d972252a
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,400 @@
......@@ -828,7 +828,7 @@ index 00000000..9b757396
+
+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
+try:
+ from nixl_wrapper import nixl_wrapper as NixlWrapper # type: ignore
+ from nixl._api import nixl_agent as NixlWrapper
+ logger.info("NIXL is available")
+except ImportError:
+ logger.warning("NIXL is not available")
......@@ -916,10 +916,10 @@ index 00000000..9b757396
+ for agent_name in agent_names:
+ self.nixl_wrapper.remove_remote_agent(agent_name)
+ for src_xfer_side_handle in self.src_xfer_side_handles.values():
+ self.nixl_wrapper.delete_xfer_side(src_xfer_side_handle)
+ self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
+ for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
+ for dst_xfer_side_handle in dst_xfer_side_handles.values():
+ self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle)
+ self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
......@@ -1092,9 +1092,9 @@ index 00000000..9b757396
+
+
+ logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000)
+ handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, staging_block_descs_ids,
+ handle = self.nixl_wrapper.make_prepped_xfer("WRITE", src_xfer_side_handle, staging_block_descs_ids,
+ dst_xfer_side_handle, dst_block_descs_ids,
+ notify_msg, "WRITE")
+ notify_msg)
+ self._transfers[notify_msg].append(handle)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
......@@ -1137,9 +1137,9 @@ index 00000000..9b757396
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ logger.debug("Transfering to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs,
+ handle = self.nixl_wrapper.initialize_xfer("WRITE", src_descs, dst_descs,
+ self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+ notify_msg, "WRITE")
+ notify_msg)
+ self._transfers[notify_msg].append(handle)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
......@@ -1179,7 +1179,7 @@ index 00000000..9b757396
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side("", descs)
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs)
+
+ # create dst xfer side handles
+ self.dst_num_blocks[engine_id] = num_blocks
......@@ -1192,7 +1192,7 @@ index 00000000..9b757396
+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs)
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs)
+
+ return agent_names
+
......@@ -1203,7 +1203,7 @@ index 00000000..9b757396
+ for handle in handles:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "DONE":
+ # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
+ # self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors?
+ continue
+ if xfer_state == "PROC":
+ running_reqs.append(handle)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment