Unverified Commit bc3700e0 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[NIXL] Support P tensor-parallel-size > D tensor-parallel-size (#27274)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent fd8afdf3
......@@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
run_tests() {
......
......@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
......@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert expected_engine_id == self.REMOTE_ENGINE_ID
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_size=remote_tp_size,
)
return {0: remote_agent_name}
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens = [
block_len // (-tp_ratio) for block_len in remote_block_lens
]
elif remote_tp_size < self.world_size:
remote_block_lens = [
block_len * tp_ratio for block_len in remote_block_lens
]
# When remote tp_size > local tp_size, handshake with multiple
# remote ranks.
num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_hanshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents
class TestNixlHandshake:
......@@ -453,7 +476,13 @@ class TestNixlHandshake:
vllm_config, connector.engine_id, hand_shake_latency=0
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
......@@ -567,6 +596,171 @@ class TestNixlHandshake:
return
raise TimeoutError("Took too long to complete async handshake.")
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""
vllm_config = create_vllm_config()
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
def check_handshake(remote_tp_size: int):
tp_ratio = remote_tp_size // local_tp_size
assert set(remote_agents.keys()) == set(range(tp_ratio))
remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
assert remote_engine_id in worker.dst_xfer_side_handles
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
range(tp_ratio)
)
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=2,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(2)
# NOTE flexiblity: a second remote with higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
# the connector allows it.
worker.REMOTE_ENGINE_ID = "remote_engine_2"
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=6,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(6)
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations for an MLA model.
"""
vllm_config = create_vllm_config()
d_tp_size = 1
p_tp_size = 2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p0.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p0.engine_id, hand_shake_latency=0
)
conn_p1.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p1.engine_id, hand_shake_latency=0
)
# Force P world size to 2 for both workers and emulate distinct tp_ranks.
# Also enable MLA path so that expected_finished_count is updated.
for rank, worker in enumerate(
(conn_p0.connector_worker, conn_p1.connector_worker)
):
worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
worker.tp_rank = rank
worker.use_mla = True
req_id = "req-ep-dp2-p0"
now = time.perf_counter()
# Register a request on P that is waiting for consumers to read
# (both workers track it).
conn_p0.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p0.connector_worker._reqs_to_process.add(req_id)
conn_p1.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p1.connector_worker._reqs_to_process.add(req_id)
# Simulate a read notification coming from D with (tp=1, dp=2).
notif = f"{req_id}:{d_tp_size}".encode()
# D0-0->P0 notif
conn_p0.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]
conn_p1.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]
# Trigger notification processing via get_finished().
done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
assert req_id in done_sending0 and req_id in done_sending1
# E2E aggregation: ensure the aggregated output marks the request
# as finished using the connector's expected_finished_count.
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)
out0 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending0,
finished_recving=None,
),
)
out1 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending1,
finished_recving=None,
),
)
aggregated = aggregator.aggregate([out0, out1], output_rank=0)
assert aggregated.kv_connector_output is not None
assert aggregated.kv_connector_output.finished_sending == {req_id}
# Producers cleaned up state for the finished request.
assert req_id not in conn_p0.connector_worker._reqs_to_send
assert req_id not in conn_p0.connector_worker._reqs_to_process
assert req_id not in conn_p1.connector_worker._reqs_to_send
assert req_id not in conn_p1.connector_worker._reqs_to_process
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
......@@ -585,6 +779,9 @@ class TestNixlHandshake:
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
......@@ -672,7 +869,6 @@ class TestNixlHandshake:
with pytest.raises(RuntimeError):
# mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1)
@patch(
......@@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
):
worker._recving_transfers = {"req1": [123]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
# Mock register_kv_cache which registers local handle
worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
# P TP = 2 * D TP case, we should register 2 local handles
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]
......@@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
mock_rel_dlist.assert_any_call(456) # src handle
assert mock_rel_dlist.call_count == 4
mock_rel_dlist.assert_any_call(455) # src handle (whole region)
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2
......
......@@ -21,6 +21,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
EngineId = str
def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
......@@ -209,12 +211,12 @@ class TpKVTopology:
"""
tp_rank: int
remote_tp_size: dict[str, int]
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: str
remote_block_size: dict[str, int]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
......@@ -256,18 +258,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
groups of size `tp_ratio`.If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
if self.tp_size >= remote_tp_size:
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
)
return self.tp_size // remote_tp_size
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
) -> int:
"""
Calculate the block size ratio between local and remote TP.
"""
......@@ -279,19 +291,19 @@ class TpKVTopology:
def tp_ratio_from_engine_id(
self,
remote_engine_id: str,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> float:
remote_engine_id: EngineId,
) -> int:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool:
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
......@@ -299,24 +311,30 @@ class TpKVTopology:
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool:
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
def get_target_remote_ranks(
self,
remote_tp_size: int,
) -> int:
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
def get_target_remote_rank_from_engine_id(
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
def get_target_remote_ranks_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_engine_id: EngineId,
) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
return self.get_target_remote_ranks(remote_tp_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment