Unverified commit 25826835, authored by Nicolò Lucchesi and committed by GitHub
Browse files

[PD] Skip `tp_size` exchange with rank0 (#19413)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent 754b00ed
...@@ -7,6 +7,8 @@ from collections import defaultdict ...@@ -7,6 +7,8 @@ from collections import defaultdict
from typing import Optional from typing import Optional
from unittest.mock import patch from unittest.mock import patch
import pytest
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker) NixlConnectorWorker)
...@@ -161,7 +163,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -161,7 +163,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency self._hand_shake_latency = hand_shake_latency
def _nixl_handshake(self, host: str, port: int) -> dict[int, str]: def _nixl_handshake(self, host: str, port: int,
remote_tp_size: int) -> dict[int, str]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication. # Mimic slow _nixl_handshake, as well as bypass zmq communication.
time.sleep(self._hand_shake_latency) time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by # These should've been done in register_kv_caches(), called by
...@@ -177,10 +180,10 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -177,10 +180,10 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
agent_metadata=FakeNixlWrapper.AGENT_METADATA, agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0], kv_caches_base_addr=[0],
num_blocks=1, num_blocks=1,
tp_size=1,
block_len=self.block_len, block_len=self.block_len,
attn_backend_name=self.backend_name, attn_backend_name=self.backend_name,
)) ),
remote_tp_size=remote_tp_size)
return {0: remote_agent_name} return {0: remote_agent_name}
...@@ -233,6 +236,8 @@ class TestNixlHandshake: ...@@ -233,6 +236,8 @@ class TestNixlHandshake:
"localhost", "localhost",
"remote_port": "remote_port":
1234, 1234,
"remote_tp_size":
1,
}) })
connector.bind_connector_metadata(metadata) connector.bind_connector_metadata(metadata)
...@@ -259,13 +264,23 @@ class TestNixlHandshake: ...@@ -259,13 +264,23 @@ class TestNixlHandshake:
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper) FakeNixlWrapper)
@pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
(1, 1),
(2, 1),
(4, 2),
(4, 4),
])
def test_async_load_kv( def test_async_load_kv(
self, self,
# dist_init is a fixture that initializes the distributed environment. # Fixture that initializes the distributed environment.
dist_init): dist_init,
# Simulate consumer-producer TP sizes.
decode_tp_size,
prefill_tp_size):
"""Test that NixlConnector's start_load_kv should be non-blocking.""" """Test that NixlConnector's start_load_kv should be non-blocking."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
# Test worker role in decode server. # Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
...@@ -280,6 +295,7 @@ class TestNixlHandshake: ...@@ -280,6 +295,7 @@ class TestNixlHandshake:
FakeNixlConnectorWorker.REMOTE_ENGINE_ID, FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost", "remote_host": "localhost",
"remote_port": 1234, "remote_port": 1234,
"remote_tp_size": prefill_tp_size,
}) })
connector.bind_connector_metadata(metadata) connector.bind_connector_metadata(metadata)
...@@ -329,6 +345,7 @@ class TestNixlHandshake: ...@@ -329,6 +345,7 @@ class TestNixlHandshake:
FakeNixlConnectorWorker.REMOTE_ENGINE_ID, FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost", "remote_host": "localhost",
"remote_port": 1234, "remote_port": 1234,
"remote_tp_size": 1,
}) })
connector.bind_connector_metadata(metadata) connector.bind_connector_metadata(metadata)
......
...@@ -62,7 +62,6 @@ class NixlAgentMetadata( ...@@ -62,7 +62,6 @@ class NixlAgentMetadata(
agent_metadata: bytes agent_metadata: bytes
kv_caches_base_addr: list[int] kv_caches_base_addr: list[int]
num_blocks: int num_blocks: int
tp_size: int
block_len: int block_len: int
attn_backend_name: str attn_backend_name: str
...@@ -73,7 +72,8 @@ class ReqMeta: ...@@ -73,7 +72,8 @@ class ReqMeta:
remote_block_ids: list[int] remote_block_ids: list[int]
remote_host: str remote_host: str
remote_port: int remote_port: int
remote_engine_id: EngineId remote_engine_id: str
tp_size: int
class NixlConnectorMetadata(KVConnectorMetadata): class NixlConnectorMetadata(KVConnectorMetadata):
...@@ -93,6 +93,8 @@ class NixlConnectorMetadata(KVConnectorMetadata): ...@@ -93,6 +93,8 @@ class NixlConnectorMetadata(KVConnectorMetadata):
remote_engine_id=kv_transfer_params["remote_engine_id"], remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"], remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"], remote_port=kv_transfer_params["remote_port"],
# P workers don't need to receive tp_size from proxy here.
tp_size=kv_transfer_params.get("tp_size", 1),
) )
...@@ -330,7 +332,7 @@ class NixlConnectorScheduler: ...@@ -330,7 +332,7 @@ class NixlConnectorScheduler:
remote_engine_id=self.engine_id, remote_engine_id=self.engine_id,
remote_host=self.side_channel_host, remote_host=self.side_channel_host,
remote_port=self.side_channel_port, remote_port=self.side_channel_port,
) tp_size=self.vllm_config.parallel_config.tensor_parallel_size)
class NixlConnectorWorker: class NixlConnectorWorker:
...@@ -473,7 +475,8 @@ class NixlConnectorWorker: ...@@ -473,7 +475,8 @@ class NixlConnectorWorker:
"Connection listener got unexpected message %s", msg) "Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data)) sock.send_multipart((identity, b"", encoded_data))
def _nixl_handshake(self, host: str, port: int) -> dict[int, str]: def _nixl_handshake(self, host: str, port: int,
remote_tp_size: int) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance.""" """Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter() start_time = time.perf_counter()
...@@ -482,7 +485,7 @@ class NixlConnectorWorker: ...@@ -482,7 +485,7 @@ class NixlConnectorWorker:
# a hack to keep us moving. We will switch when moving to etcd # a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler. # or where we have a single ZMQ socket in the scheduler.
def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: def handshake(path: str, rank: int) -> str:
# Send query for the request. # Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock: with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG) sock.send(GET_META_MSG)
...@@ -492,33 +495,25 @@ class NixlConnectorWorker: ...@@ -492,33 +495,25 @@ class NixlConnectorWorker:
got_metadata_time = time.perf_counter() got_metadata_time = time.perf_counter()
# Register Remote agent. # Register Remote agent.
remote_agent_name = self.add_remote_agent(metadata, rank) remote_agent_name = self.add_remote_agent(
metadata, rank, remote_tp_size)
setup_agent_time = time.perf_counter() setup_agent_time = time.perf_counter()
logger.debug("NIXL handshake: get metadata took: %s", logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time) got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s", logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time) setup_agent_time - got_metadata_time)
return metadata, remote_agent_name return remote_agent_name
# Handshake with remote agent-rank0 first to get the tp_size of remote
path = make_zmq_path("tcp", host, port)
logger.debug("Querying master rank metadata on path: %s", path)
rank_to_agent_name: dict[int, str] = {}
metadata, rank_to_agent_name[0] = handshake(path, 0)
# Handshake only with the other TP remote the current local rank will # Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i. # pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
p_remote_rank = self.tp_rank // tp_ratio p_remote_rank = self.tp_rank // tp_ratio
if p_remote_rank > 0:
path = make_zmq_path("tcp", host, port + p_remote_rank) path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s", logger.debug("Querying metadata on path: %s at remote rank %s", path,
path, p_remote_rank) p_remote_rank)
_, rank_to_agent_name[p_remote_rank] = handshake( # Remote rank -> agent name.
path, p_remote_rank) return {p_remote_rank: handshake(path, p_remote_rank)}
return rank_to_agent_name
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl.""" """Register the KV Cache data in nixl."""
...@@ -645,7 +640,6 @@ class NixlConnectorWorker: ...@@ -645,7 +640,6 @@ class NixlConnectorWorker:
agent_metadata=self.nixl_wrapper.get_agent_metadata(), agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len, block_len=self.block_len,
attn_backend_name=self.backend_name) attn_backend_name=self.backend_name)
ready_event = threading.Event() ready_event = threading.Event()
...@@ -659,7 +653,8 @@ class NixlConnectorWorker: ...@@ -659,7 +653,8 @@ class NixlConnectorWorker:
def add_remote_agent(self, def add_remote_agent(self,
nixl_agent_meta: NixlAgentMetadata, nixl_agent_meta: NixlAgentMetadata,
remote_tp_rank: int = 0) -> str: remote_tp_rank: int = 0,
remote_tp_size: int = 1) -> str:
""" """
Add the remote NIXL agent and prepare the descriptors for reading cache Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote. blocks from remote.
...@@ -704,9 +699,9 @@ class NixlConnectorWorker: ...@@ -704,9 +699,9 @@ class NixlConnectorWorker:
return self._remote_agents[engine_id][remote_tp_rank] return self._remote_agents[engine_id][remote_tp_rank]
if engine_id in self._tp_size: if engine_id in self._tp_size:
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size assert self._tp_size[engine_id] == remote_tp_size
else: else:
self._tp_size[engine_id] = nixl_agent_meta.tp_size self._tp_size[engine_id] = remote_tp_size
# We may eventually enable this after asserting equality in cache # We may eventually enable this after asserting equality in cache
# layout and close outputs. # layout and close outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name assert nixl_agent_meta.attn_backend_name == self.backend_name
...@@ -756,9 +751,7 @@ class NixlConnectorWorker: ...@@ -756,9 +751,7 @@ class NixlConnectorWorker:
# rank. With heterogeneous TP, prepare the descriptors by splitting the # rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P). # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
p_remote_tp_rank = self.tp_rank // tp_ratio
# Only register the remote's descriptors if current rank pulls from it. # Only register the remote's descriptors if current rank pulls from it.
if p_remote_tp_rank == remote_tp_rank:
self.kv_caches_base_addr[ self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \ rank_offset = self.tp_rank % tp_ratio * self.block_len \
...@@ -917,7 +910,7 @@ class NixlConnectorWorker: ...@@ -917,7 +910,7 @@ class NixlConnectorWorker:
if fut is None: if fut is None:
fut = self._handshake_initiation_executor.submit( fut = self._handshake_initiation_executor.submit(
self._nixl_handshake, meta.remote_host, self._nixl_handshake, meta.remote_host,
meta.remote_port) meta.remote_port, meta.tp_size)
self._handshake_futures[remote_engine_id] = fut self._handshake_futures[remote_engine_id] = fut
def done_callback(f: Future[dict[int, str]], def done_callback(f: Future[dict[int, str]],
...@@ -957,13 +950,9 @@ class NixlConnectorWorker: ...@@ -957,13 +950,9 @@ class NixlConnectorWorker:
remote_block_ids=meta.remote_block_ids, remote_block_ids=meta.remote_block_ids,
) )
def _read_blocks( def _read_blocks(self, local_block_ids: list[int],
self, remote_block_ids: list[int], dst_engine_id: str,
local_block_ids: list[int], request_id: str):
remote_block_ids: list[int],
dst_engine_id: str,
request_id: str,
):
# NOTE(rob): having the staging blocks be on the READER side is # NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors). # not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the # after we detect the txn is complete (which means we cannot make the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment