Unverified Commit cc3993b0 authored by zhanqiuhu's avatar zhanqiuhu Committed by GitHub
Browse files

nixl refactor [2/N]: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology (#39529)


Signed-off-by: default avatarZhanqiu Hu <zhu@redhat.com>
parent 50dd4cb4
...@@ -631,7 +631,7 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes(): ...@@ -631,7 +631,7 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
mock_thread.return_value.is_alive.return_value = False mock_thread.return_value.is_alive.return_value = False
worker.use_mla = True worker.use_mla = True
worker.kv_topo.is_mla = True worker.transfer_topo.is_mla = True
# MLA cache tensor: shape[-2] is the block size. # MLA cache tensor: shape[-2] is the block size.
mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16) mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16)
...@@ -692,9 +692,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): ...@@ -692,9 +692,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
# Override TP rank/size to simulate P TP=2 # Override TP rank/size to simulate P TP=2
prefill_worker.tp_rank = P_TP_RANK prefill_worker.tp_rank = P_TP_RANK
prefill_worker.tp_size = P_TP_SIZE prefill_worker.tp_size = P_TP_SIZE
# Update shared dict so kv_topo sees correct TP size
prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE
prefill_worker.kv_topo.tp_rank = P_TP_RANK prefill_worker.transfer_topo.tp_rank = P_TP_RANK
prefill_worker.transfer_topo.tp_size = P_TP_SIZE
prefill_worker.kv_caches_base_addr = [0x1000] prefill_worker.kv_caches_base_addr = [0x1000]
prefill_worker.block_len_per_layer = [local_block_len] prefill_worker.block_len_per_layer = [local_block_len]
...@@ -714,7 +714,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): ...@@ -714,7 +714,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
send_meta.ready.set() send_meta.ready.set()
# Compute target D ranks using the production code path # Compute target D ranks using the production code path
target_d_ranks = prefill_worker.kv_topo.get_target_remote_ranks(d_tp_size) target_d_ranks = prefill_worker.transfer_topo.handshake_target_ranks(d_tp_size)
mock_socket = AsyncMock(spec=zmq.asyncio.Socket) mock_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_socket.send_multipart = AsyncMock() mock_socket.send_multipart = AsyncMock()
......
...@@ -21,7 +21,7 @@ from vllm import LLM ...@@ -21,7 +21,7 @@ from vllm import LLM
from vllm.config import KVTransferConfig, set_current_vllm_config from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator, KVOutputAggregator,
TpKVTopology, TransferTopology,
get_current_attn_backend, get_current_attn_backend,
) )
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl from vllm.distributed.kv_transfer.kv_connector.v1 import nixl
...@@ -463,19 +463,20 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -463,19 +463,20 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
test_shape = self.attn_backends[0].get_kv_cache_shape( test_shape = self.attn_backends[0].get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
) )
self.kv_topo = TpKVTopology( self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.world_size,
block_size=self.block_size,
engine_id=self.engine_id, engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla, is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(), total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=self.attn_backends, attn_backends=self.attn_backends,
tensor_shape=test_shape, tensor_shape=test_shape,
) )
self.compat_hash = compute_nixl_compatibility_hash( self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
) )
def _nixl_handshake( def _nixl_handshake(
...@@ -496,7 +497,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -496,7 +497,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# Adjust remote block length metadata to satisfy heterogeneous TP # Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation. # invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer) remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size: if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller # P TP > D TP case, block_len of remote is smaller
remote_block_lens = [ remote_block_lens = [
...@@ -731,8 +732,9 @@ class TestNixlHandshake: ...@@ -731,8 +732,9 @@ class TestNixlHandshake:
assert set(remote_agents.keys()) == set(range(tp_ratio)) assert set(remote_agents.keys()) == set(range(tp_ratio))
remote_engine_id = worker.REMOTE_ENGINE_ID remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size remote_info = worker.transfer_topo.get_engine_info(remote_engine_id)
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id) assert remote_info.remote_tp_size == remote_tp_size
assert -tp_ratio == worker.transfer_topo.tp_ratio(remote_tp_size)
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks # ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
...@@ -796,7 +798,7 @@ class TestNixlHandshake: ...@@ -796,7 +798,7 @@ class TestNixlHandshake:
(conn_p0.connector_worker, conn_p1.connector_worker) (conn_p0.connector_worker, conn_p1.connector_worker)
): ):
worker.world_size = p_tp_size worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size} worker.transfer_topo.tp_size = p_tp_size
worker.tp_rank = rank worker.tp_rank = rank
worker.use_mla = True worker.use_mla = True
...@@ -2337,7 +2339,7 @@ def test_compatibility_hash_validation( ...@@ -2337,7 +2339,7 @@ def test_compatibility_hash_validation(
remote_hash = compute_nixl_compatibility_hash( remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config, remote_vllm_config,
decode_worker.backend_name, decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks, decode_worker.transfer_topo.cross_layers_blocks,
) )
prefill_block_size = config_overrides.get("block_size", 16) prefill_block_size = config_overrides.get("block_size", 16)
...@@ -2424,12 +2426,13 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2424,12 +2426,13 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
test_shape = backend.get_kv_cache_shape( test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
) )
decode_worker.kv_topo = TpKVTopology( decode_worker.transfer_topo = TransferTopology(
tp_rank=decode_worker.tp_rank, tp_rank=decode_worker.tp_rank,
tp_size=decode_worker.world_size,
block_size=decode_worker.block_size,
engine_id=decode_worker.engine_id, engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla, is_mla=decode_worker.use_mla,
is_mamba=False,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backends=[backend], attn_backends=[backend],
tensor_shape=test_shape, tensor_shape=test_shape,
...@@ -2438,7 +2441,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2438,7 +2441,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_worker.compat_hash = compute_nixl_compatibility_hash( decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config, decode_worker.vllm_config,
decode_worker.backend_name, decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks, decode_worker.transfer_topo.cross_layers_blocks,
) )
if error_scenario == "handshake_decode_error": if error_scenario == "handshake_decode_error":
......
...@@ -152,13 +152,14 @@ def test_read_blocks_for_req_expands_remote_ids( ...@@ -152,13 +152,14 @@ def test_read_blocks_for_req_expands_remote_ids(
remote_engine_id = "remote-engine" remote_engine_id = "remote-engine"
if has_mamba: if has_mamba:
worker._mamba_phys_ratio = {remote_engine_id: remote_ratio} worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio}
# Mock kv_topo: empty remote ranks skips the transfer machinery entirely, # Mock transfer_topo: empty remote ranks skips the transfer machinery
# isolating the block-ID expansion logic. # entirely, isolating the block-ID expansion logic.
worker.kv_topo = MagicMock() worker.transfer_topo = MagicMock()
worker.kv_topo.get_target_remote_ranks_from_engine_id.return_value = [] worker.transfer_topo.target_remote_ranks.return_value = []
worker.kv_topo.tp_ratio_from_engine_id.return_value = 1 worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1)
worker.transfer_topo.tp_ratio.return_value = 1
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
...@@ -317,7 +318,7 @@ def test_get_block_descs_ids_hybrid_ssm(): ...@@ -317,7 +318,7 @@ def test_get_block_descs_ids_hybrid_ssm():
worker._has_mamba = True worker._has_mamba = True
worker._is_mamba_group = [False, True] worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1 worker._physical_blocks_per_logical_kv_block = 1
worker._mamba_phys_ratio = {engine_id: 1} worker._physical_blocks_per_logical = {engine_id: 1}
worker.block_len_per_layer = [100] worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling) # num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks worker.num_descs = 2 * num_blocks
...@@ -355,7 +356,7 @@ def test_get_block_descs_ids_kernel_block_mismatch(): ...@@ -355,7 +356,7 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker._has_mamba = True worker._has_mamba = True
worker._is_mamba_group = [False, True] worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio worker._physical_blocks_per_logical_kv_block = ratio
worker._mamba_phys_ratio = {engine_id: ratio} worker._physical_blocks_per_logical = {engine_id: ratio}
worker.block_len_per_layer = [100] worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800 worker.num_descs = 2 * num_blocks # 800
...@@ -532,15 +533,15 @@ def test_has_mamba_init( ...@@ -532,15 +533,15 @@ def test_has_mamba_init(
((9216, 524288), 4096, 131), ((9216, 524288), 4096, 131),
], ],
) )
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio): def test_compute_physical_blocks_per_logical(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_mamba_phys_ratio is TP-dependent. """Verify that compute_physical_blocks_per_logical is TP-dependent.
With dimension-sharded Mamba state, the ratio differs across TP sizes With dimension-sharded Mamba state, the ratio differs across TP sizes
(e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why (e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
_mamba_phys_ratio must be stored per-engine. _physical_blocks_per_logical must be stored per-engine.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
compute_mamba_phys_ratio, compute_physical_blocks_per_logical,
) )
assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio assert compute_physical_blocks_per_logical(ssm_sizes, block_len) == expected_ratio
...@@ -21,7 +21,7 @@ from vllm import envs ...@@ -21,7 +21,7 @@ from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId, EngineId,
TpKVTopology, TransferTopology,
get_current_attn_backend, get_current_attn_backend,
get_current_attn_backends, get_current_attn_backends,
) )
...@@ -764,13 +764,13 @@ class MooncakeConnectorWorker: ...@@ -764,13 +764,13 @@ class MooncakeConnectorWorker:
logger.debug("Detected kv cache layout %s", self.kv_cache_layout) logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size} self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} self.transfer_topo = TransferTopology(
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.tp_size,
block_size=self.block_size,
engine_id=self.engine_id, engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla, is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(), total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=[backend], attn_backends=[backend],
) )
...@@ -911,7 +911,7 @@ class MooncakeConnectorWorker: ...@@ -911,7 +911,7 @@ class MooncakeConnectorWorker:
self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata
): ):
pending_reqs: dict[ReqId, SendBlockMeta] = {} pending_reqs: dict[ReqId, SendBlockMeta] = {}
remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size) remote_tp_ranks = self.transfer_topo.handshake_target_ranks(meta.remote_tp_size)
if meta.remote_tp_rank not in remote_tp_ranks: if meta.remote_tp_rank not in remote_tp_ranks:
# This D worker does not pair with the P worker. # This D worker does not pair with the P worker.
msg = ( msg = (
...@@ -1256,7 +1256,7 @@ class MooncakeConnectorWorker: ...@@ -1256,7 +1256,7 @@ class MooncakeConnectorWorker:
seen_base_addresses = [] seen_base_addresses = []
self.block_len_per_layer = [] self.block_len_per_layer = []
split_k_and_v = self.kv_topo.split_k_and_v split_k_and_v = self.transfer_topo.split_k_and_v
tensor_size_bytes = None tensor_size_bytes = None
for layer_name, cache_or_caches in kv_caches.items(): for layer_name, cache_or_caches in kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
...@@ -1495,8 +1495,8 @@ class MooncakeConnectorWorker: ...@@ -1495,8 +1495,8 @@ class MooncakeConnectorWorker:
remote_engine_id: EngineId, remote_engine_id: EngineId,
pull_metas: dict[ReqId, PullReqMeta], pull_metas: dict[ReqId, PullReqMeta],
): ):
remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( remote_tp_ranks = self.transfer_topo.handshake_target_ranks(
remote_engine_id self._tp_size[remote_engine_id]
) )
count = len(remote_tp_ranks) count = len(remote_tp_ranks)
logger.debug( logger.debug(
...@@ -1587,7 +1587,7 @@ class MooncakeConnectorWorker: ...@@ -1587,7 +1587,7 @@ class MooncakeConnectorWorker:
) )
def _producer_cache_is_replicated(self) -> bool: def _producer_cache_is_replicated(self) -> bool:
return self.kv_topo.replicates_kv_cache(self.engine_id) return self.transfer_topo.local_replicates_kv_cache
def _get_transfer_regions( def _get_transfer_regions(
self, base_addrs: list[int], block_lens: list[int] self, base_addrs: list[int], block_lens: list[int]
...@@ -1595,7 +1595,7 @@ class MooncakeConnectorWorker: ...@@ -1595,7 +1595,7 @@ class MooncakeConnectorWorker:
return _expand_transfer_regions( return _expand_transfer_regions(
base_addrs=base_addrs, base_addrs=base_addrs,
block_lens=block_lens, block_lens=block_lens,
is_kv_layout_blocks_first=self.kv_topo.is_kv_layout_blocks_first, is_kv_layout_blocks_first=self.transfer_topo.is_kv_layout_blocks_first,
) )
def _get_sender_transfer_plan( def _get_sender_transfer_plan(
......
...@@ -151,7 +151,9 @@ def derive_mamba_conv_split( ...@@ -151,7 +151,9 @@ def derive_mamba_conv_split(
) )
def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int: def compute_physical_blocks_per_logical(
ssm_sizes: tuple[int, ...], block_len: int
) -> int:
"""Derive _physical_blocks_per_logical_kv_block from remote metadata. """Derive _physical_blocks_per_logical_kv_block from remote metadata.
The remote engine's ratio is not sent directly in the handshake, so we The remote engine's ratio is not sent directly in the handshake, so we
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment