Unverified Commit c5f10cc1 authored by ZhengHongming888's avatar ZhengHongming888 Committed by GitHub
Browse files

add cpu option for p/d in nixl_connector (#28356)


Signed-off-by: default avatarHongming Zheng <hongming.zheng@intel.com>
parent d1431523
...@@ -91,6 +91,7 @@ _NIXL_SUPPORTED_DEVICE = { ...@@ -91,6 +91,7 @@ _NIXL_SUPPORTED_DEVICE = {
), ),
"tpu": ("cpu",), "tpu": ("cpu",),
"xpu": ("cpu",), "xpu": ("cpu",),
"cpu": ("cpu",),
} }
# support for oot platform by providing mapping in current_platform # support for oot platform by providing mapping in current_platform
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
...@@ -348,7 +349,13 @@ class NixlConnectorScheduler: ...@@ -348,7 +349,13 @@ class NixlConnectorScheduler:
+ vllm_config.parallel_config.data_parallel_rank + vllm_config.parallel_config.data_parallel_rank
) )
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" if current_platform.device_type == "cpu":
self.use_host_buffer = False
else:
self.use_host_buffer = (
vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
)
logger.info("Initializing NIXL Scheduler %s", engine_id) logger.info("Initializing NIXL Scheduler %s", engine_id)
# Background thread for handling new handshake requests. # Background thread for handling new handshake requests.
...@@ -820,7 +827,11 @@ class NixlConnectorWorker: ...@@ -820,7 +827,11 @@ class NixlConnectorWorker:
# cpu kv buffer for xfer # cpu kv buffer for xfer
# used when device memory can not be registered under nixl # used when device memory can not be registered under nixl
self.host_xfer_buffers: dict[str, torch.Tensor] = {} self.host_xfer_buffers: dict[str, torch.Tensor] = {}
if self.device_type == "cpu":
self.use_host_buffer = False
else:
self.use_host_buffer = self.kv_buffer_device == "cpu" self.use_host_buffer = self.kv_buffer_device == "cpu"
# support for oot platform which can't register nixl memory # support for oot platform which can't register nixl memory
# type based on kv_buffer_device # type based on kv_buffer_device
nixl_memory_type = current_platform.get_nixl_memory_type() nixl_memory_type = current_platform.get_nixl_memory_type()
...@@ -1021,6 +1032,9 @@ class NixlConnectorWorker: ...@@ -1021,6 +1032,9 @@ class NixlConnectorWorker:
# Set a no-op if the host buffer is not cpu. # Set a no-op if the host buffer is not cpu.
if self.kv_buffer_device != "cpu": if self.kv_buffer_device != "cpu":
return return
# Set a no-op if self.device_type is 'cpu'.
if self.device_type == "cpu":
return
assert self.use_host_buffer assert self.use_host_buffer
self.copy_blocks = copy_operation self.copy_blocks = copy_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment