Unverified Commit 1696c864 authored by Zhewen Li's avatar Zhewen Li Committed by GitHub
Browse files

[Bugfix][Mooncake] Fix thread-local CUDA context for NVLink transfers in _send_blocks (#39548)


Signed-off-by: default avatarZhewen Li <zhewenli@inferact.ai>
Co-authored-by: default avatarZhewen Li <zhewenli@inferact.ai>
parent 2ad10292
......@@ -41,6 +41,7 @@ from vllm.distributed.parallel_state import (
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
......@@ -645,6 +646,10 @@ class MooncakeConnectorWorker:
logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)
self.vllm_config = vllm_config
# Capture device BEFORE TransferEngine init — MNNVL's NVLink allocator
# may change the current CUDA device during engine.initialize().
self.device_id = torch.accelerator.current_device_index()
current_platform.set_device(self.device_id)
self.engine = TransferEngine()
self.hostname = get_ip()
......@@ -705,9 +710,12 @@ class MooncakeConnectorWorker:
# For kv_both, we will act both prefiller and decoder.
if not self.is_kv_consumer:
# Background threads for sending kvcaches to D.
# Each pool thread must be bound to the correct CUDA device
# because CUDA device selection is thread-local.
self._sender_executor = ThreadPoolExecutor(
max_workers=self.num_sender_workers,
thread_name_prefix="vllm-mooncake-sender",
initializer=self._bind_sender_thread_device,
)
logger.debug(
"Mooncake Prefiller: use %d workers to send kvcaches",
......@@ -1193,6 +1201,12 @@ class MooncakeConnectorWorker:
return src_ptrs, dst_ptrs, lengths, err_reqs, err_msg
def _bind_sender_thread_device(self) -> None:
"""ThreadPoolExecutor initializer — binds each pool thread to the
correct CUDA device. CUDA device selection is thread-local, so
without this, NVLink transfers fail for TP ranks > 0."""
current_platform.set_device(self.device_id)
def _send_blocks(
self,
remote_session: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment