"csrc/vscode:/vscode.git/clone" did not exist on "38c498b8e3aaec95049f384edfc56ca12cbe1839"
Unverified commit cc3993b0, authored by zhanqiuhu and committed by GitHub
Browse files

nixl refactor [2/N]: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology (#39529)


Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
parent 50dd4cb4
......@@ -631,7 +631,7 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
mock_thread.return_value.is_alive.return_value = False
worker.use_mla = True
worker.kv_topo.is_mla = True
worker.transfer_topo.is_mla = True
# MLA cache tensor: shape[-2] is the block size.
mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16)
......@@ -692,9 +692,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
# Override TP rank/size to simulate P TP=2
prefill_worker.tp_rank = P_TP_RANK
prefill_worker.tp_size = P_TP_SIZE
# Update shared dict so kv_topo sees correct TP size
prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE
prefill_worker.kv_topo.tp_rank = P_TP_RANK
prefill_worker.transfer_topo.tp_rank = P_TP_RANK
prefill_worker.transfer_topo.tp_size = P_TP_SIZE
prefill_worker.kv_caches_base_addr = [0x1000]
prefill_worker.block_len_per_layer = [local_block_len]
......@@ -714,7 +714,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
send_meta.ready.set()
# Compute target D ranks using the production code path
target_d_ranks = prefill_worker.kv_topo.get_target_remote_ranks(d_tp_size)
target_d_ranks = prefill_worker.transfer_topo.handshake_target_ranks(d_tp_size)
mock_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_socket.send_multipart = AsyncMock()
......
......@@ -21,7 +21,7 @@ from vllm import LLM
from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
TransferTopology,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl
......@@ -463,19 +463,20 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
test_shape = self.attn_backends[0].get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank,
tp_size=self.world_size,
block_size=self.block_size,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=self.attn_backends,
tensor_shape=test_shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
)
def _nixl_handshake(
......@@ -496,7 +497,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens = [
......@@ -731,8 +732,9 @@ class TestNixlHandshake:
assert set(remote_agents.keys()) == set(range(tp_ratio))
remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
remote_info = worker.transfer_topo.get_engine_info(remote_engine_id)
assert remote_info.remote_tp_size == remote_tp_size
assert -tp_ratio == worker.transfer_topo.tp_ratio(remote_tp_size)
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
......@@ -796,7 +798,7 @@ class TestNixlHandshake:
(conn_p0.connector_worker, conn_p1.connector_worker)
):
worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
worker.transfer_topo.tp_size = p_tp_size
worker.tp_rank = rank
worker.use_mla = True
......@@ -2337,7 +2339,7 @@ def test_compatibility_hash_validation(
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
decode_worker.transfer_topo.cross_layers_blocks,
)
prefill_block_size = config_overrides.get("block_size", 16)
......@@ -2424,12 +2426,13 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
decode_worker.kv_topo = TpKVTopology(
decode_worker.transfer_topo = TransferTopology(
tp_rank=decode_worker.tp_rank,
tp_size=decode_worker.world_size,
block_size=decode_worker.block_size,
engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
is_mamba=False,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backends=[backend],
tensor_shape=test_shape,
......@@ -2438,7 +2441,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
decode_worker.transfer_topo.cross_layers_blocks,
)
if error_scenario == "handshake_decode_error":
......
......@@ -152,13 +152,14 @@ def test_read_blocks_for_req_expands_remote_ids(
remote_engine_id = "remote-engine"
if has_mamba:
worker._mamba_phys_ratio = {remote_engine_id: remote_ratio}
worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio}
# Mock kv_topo: empty remote ranks skips the transfer machinery entirely,
# isolating the block-ID expansion logic.
worker.kv_topo = MagicMock()
worker.kv_topo.get_target_remote_ranks_from_engine_id.return_value = []
worker.kv_topo.tp_ratio_from_engine_id.return_value = 1
# Mock transfer_topo: empty remote ranks skips the transfer machinery
# entirely, isolating the block-ID expansion logic.
worker.transfer_topo = MagicMock()
worker.transfer_topo.target_remote_ranks.return_value = []
worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1)
worker.transfer_topo.tp_ratio.return_value = 1
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
......@@ -317,7 +318,7 @@ def test_get_block_descs_ids_hybrid_ssm():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
worker._mamba_phys_ratio = {engine_id: 1}
worker._physical_blocks_per_logical = {engine_id: 1}
worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
......@@ -355,7 +356,7 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
worker._mamba_phys_ratio = {engine_id: ratio}
worker._physical_blocks_per_logical = {engine_id: ratio}
worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800
......@@ -532,15 +533,15 @@ def test_has_mamba_init(
((9216, 524288), 4096, 131),
],
)
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_mamba_phys_ratio is TP-dependent.
def test_compute_physical_blocks_per_logical(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_physical_blocks_per_logical is TP-dependent.
With dimension-sharded Mamba state, the ratio differs across TP sizes
(e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
_mamba_phys_ratio must be stored per-engine.
_physical_blocks_per_logical must be stored per-engine.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
compute_mamba_phys_ratio,
compute_physical_blocks_per_logical,
)
assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio
assert compute_physical_blocks_per_logical(ssm_sizes, block_len) == expected_ratio
......@@ -21,7 +21,7 @@ from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
TransferTopology,
get_current_attn_backend,
get_current_attn_backends,
)
......@@ -764,13 +764,13 @@ class MooncakeConnectorWorker:
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
self.kv_topo = TpKVTopology(
self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank,
tp_size=self.tp_size,
block_size=self.block_size,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=[backend],
)
......@@ -911,7 +911,7 @@ class MooncakeConnectorWorker:
self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata
):
pending_reqs: dict[ReqId, SendBlockMeta] = {}
remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size)
remote_tp_ranks = self.transfer_topo.handshake_target_ranks(meta.remote_tp_size)
if meta.remote_tp_rank not in remote_tp_ranks:
# This D worker does not pair with the P worker.
msg = (
......@@ -1256,7 +1256,7 @@ class MooncakeConnectorWorker:
seen_base_addresses = []
self.block_len_per_layer = []
split_k_and_v = self.kv_topo.split_k_and_v
split_k_and_v = self.transfer_topo.split_k_and_v
tensor_size_bytes = None
for layer_name, cache_or_caches in kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
......@@ -1495,8 +1495,8 @@ class MooncakeConnectorWorker:
remote_engine_id: EngineId,
pull_metas: dict[ReqId, PullReqMeta],
):
remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
remote_engine_id
remote_tp_ranks = self.transfer_topo.handshake_target_ranks(
self._tp_size[remote_engine_id]
)
count = len(remote_tp_ranks)
logger.debug(
......@@ -1587,7 +1587,7 @@ class MooncakeConnectorWorker:
)
def _producer_cache_is_replicated(self) -> bool:
return self.kv_topo.replicates_kv_cache(self.engine_id)
return self.transfer_topo.local_replicates_kv_cache
def _get_transfer_regions(
self, base_addrs: list[int], block_lens: list[int]
......@@ -1595,7 +1595,7 @@ class MooncakeConnectorWorker:
return _expand_transfer_regions(
base_addrs=base_addrs,
block_lens=block_lens,
is_kv_layout_blocks_first=self.kv_topo.is_kv_layout_blocks_first,
is_kv_layout_blocks_first=self.transfer_topo.is_kv_layout_blocks_first,
)
def _get_sender_transfer_plan(
......
......@@ -151,7 +151,9 @@ def derive_mamba_conv_split(
)
def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int:
def compute_physical_blocks_per_logical(
ssm_sizes: tuple[int, ...], block_len: int
) -> int:
"""Derive _physical_blocks_per_logical_kv_block from remote metadata.
The remote engine's ratio is not sent directly in the handshake, so we
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment