Unverified Commit bc3700e0 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[NIXL] Support P tensor-parallel-size > D tensor-parallel-size (#27274)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent fd8afdf3
...@@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh" ...@@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs=( configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2" "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2" "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
) )
run_tests() { run_tests() {
......
...@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
def _nixl_handshake( def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
...@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert expected_engine_id == self.REMOTE_ENGINE_ID assert expected_engine_id == self.REMOTE_ENGINE_ID
remote_agent_name = self.add_remote_agent( # Adjust remote block length metadata to satisfy heterogeneous TP
NixlAgentMetadata( # invariants enforced during handshake validation.
engine_id=self.REMOTE_ENGINE_ID, remote_block_lens = list(self.block_len_per_layer)
agent_metadata=FakeNixlWrapper.AGENT_METADATA, tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
kv_caches_base_addr=[0], if remote_tp_size > self.world_size:
device_id=0, # P TP > D TP case, block_len of remote is smaller
num_blocks=1, remote_block_lens = [
block_lens=self.block_len_per_layer, block_len // (-tp_ratio) for block_len in remote_block_lens
# `self.kv_cache_layout` is only forced to HND when vllm engine ]
# is started. We mock HND here. elif remote_tp_size < self.world_size:
kv_cache_layout="HND", remote_block_lens = [
block_size=self.block_size, block_len * tp_ratio for block_len in remote_block_lens
), ]
remote_tp_size=remote_tp_size,
) # When remote tp_size > local tp_size, handshake with multiple
return {0: remote_agent_name} # remote ranks.
num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_hanshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents
class TestNixlHandshake: class TestNixlHandshake:
...@@ -453,7 +476,13 @@ class TestNixlHandshake: ...@@ -453,7 +476,13 @@ class TestNixlHandshake:
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4 num_xfers = 4
while True: while True:
# For the same request_id, initiate multiple xfers across different # For the same request_id, initiate multiple xfers across different
...@@ -567,6 +596,171 @@ class TestNixlHandshake: ...@@ -567,6 +596,171 @@ class TestNixlHandshake:
return return
raise TimeoutError("Took too long to complete async handshake.") raise TimeoutError("Took too long to complete async handshake.")
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
    self, local_tp_size: int, dist_init
):
    """
    Handshake with remotes whose TP size is larger than the local TP size.

    Two remote engines are handshaken one after the other (remote TP 2,
    then remote TP 6) and the worker's per-engine bookkeeping is checked
    after each handshake.
    """
    # NOTE(review): the parametrized `local_tp_size` is overwritten right
    # here, so both parametrizations currently run with local TP = 1.
    # Confirm whether the local TP = 2 case was meant to be exercised.
    local_tp_size = 1
    vllm_config = create_vllm_config()
    vllm_config.parallel_config.tensor_parallel_size = local_tp_size

    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0
    )
    worker = connector.connector_worker

    # Minimal local registration params used by add_remote_agent
    worker.slot_size_per_layer = [4096]
    worker.block_len_per_layer = [4096 * worker.block_size]
    worker.num_blocks = 1
    worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
    worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]

    def handshake_and_verify(remote_tp_size: int) -> None:
        """Handshake with the current remote engine, then check worker state."""
        remote_engine_id = worker.REMOTE_ENGINE_ID
        remote_agents = worker._nixl_handshake(
            host="localhost",
            port=1234,
            remote_tp_size=remote_tp_size,
            expected_engine_id=remote_engine_id,
        )

        tp_ratio = remote_tp_size // local_tp_size
        expected_ranks = set(range(tp_ratio))

        # One remote agent per remote rank this worker reads from.
        assert set(remote_agents.keys()) == expected_ranks
        assert worker._tp_size[remote_engine_id] == remote_tp_size
        # Remote TP > local TP is reported as a negative ratio.
        assert worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id) == -tp_ratio

        # The local source handles are keyed by the (negative) ratio and
        # hold `tp_ratio` entries.
        chunked_src_handles = worker.src_xfer_handles_by_tp_ratio
        assert -tp_ratio in chunked_src_handles
        assert len(chunked_src_handles[-tp_ratio]) == tp_ratio

        dst_handles = worker.dst_xfer_side_handles
        assert remote_engine_id in dst_handles
        assert set(dst_handles[remote_engine_id].keys()) == expected_ranks

    handshake_and_verify(2)

    # NOTE flexibility: a second remote with a higher number of ranks is
    # discovered. This is not a scenario we actively support right now, but
    # the connector allows it.
    worker.REMOTE_ENGINE_ID = "remote_engine_2"
    handshake_and_verify(6)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
    self, local_tp_size: int, dist_init
):
    """
    Check finished-request reporting on P (TP=2, MLA) for a D with TP=1.

    Two fake P workers each track the same pending request, each receives a
    read notification, and both must report the request as done sending.
    The aggregated output must then mark it finished, and both workers must
    have dropped their tracking state for it.
    """
    # NOTE(review): the parametrized `local_tp_size` is never read in this
    # test; the TP sizes are fixed by the two constants below. Confirm
    # whether the parametrization is intended.
    d_tp_size = 1
    p_tp_size = 2
    vllm_config = create_vllm_config()

    # Build two separate connectors/workers to emulate P TP=2 ranks.
    conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    conn_p0.connector_worker = FakeNixlConnectorWorker(
        vllm_config, conn_p0.engine_id, hand_shake_latency=0
    )
    conn_p1.connector_worker = FakeNixlConnectorWorker(
        vllm_config, conn_p1.engine_id, hand_shake_latency=0
    )
    p_workers = (conn_p0.connector_worker, conn_p1.connector_worker)

    # Force P world size to 2 for both workers and emulate distinct tp_ranks.
    # Also enable MLA path so that expected_finished_count is updated.
    for rank, p_worker in enumerate(p_workers):
        p_worker.world_size = p_tp_size
        p_worker.kv_topo.remote_tp_size = {p_worker.engine_id: p_tp_size}
        p_worker.tp_rank = rank
        p_worker.use_mla = True

    req_id = "req-ep-dp2-p0"
    now = time.perf_counter()

    # Simulate a read notification coming from D with (tp=1, dp=2).
    notif = f"{req_id}:{d_tp_size}".encode()

    for p_worker in p_workers:
        # Register a request on P that is waiting for consumers to read
        # (both workers track it).
        p_worker._reqs_to_send[req_id] = now + 10.0
        p_worker._reqs_to_process.add(req_id)
        # Every worker gets its own stub returning the same notification.
        p_worker.nixl_wrapper.get_new_notifs = lambda: {
            "agent": [notif]
        }  # type: ignore[method-assign]

    # Trigger notification processing via get_finished().
    done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
    done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
    assert req_id in done_sending0
    assert req_id in done_sending1

    # E2E aggregation: ensure the aggregated output marks the request
    # as finished using the connector's expected_finished_count.
    from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

    def runner_output_for(finished_sending):
        """Build a single-request runner output carrying `finished_sending`."""
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index={req_id: 0},
            sampled_token_ids=[[0]],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[None],
            kv_connector_output=KVConnectorOutput(
                finished_sending=finished_sending,
                finished_recving=None,
            ),
        )

    aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)
    out0 = runner_output_for(done_sending0)
    out1 = runner_output_for(done_sending1)
    aggregated = aggregator.aggregate([out0, out1], output_rank=0)
    assert aggregated.kv_connector_output is not None
    assert aggregated.kv_connector_output.finished_sending == {req_id}

    # Producers cleaned up state for the finished request.
    for p_worker in p_workers:
        assert req_id not in p_worker._reqs_to_send
        assert req_id not in p_worker._reqs_to_process
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
...@@ -585,6 +779,9 @@ class TestNixlHandshake: ...@@ -585,6 +779,9 @@ class TestNixlHandshake:
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id vllm_config, connector.engine_id
) )
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
total_reqs = 5 total_reqs = 5
for i in range(total_reqs): for i in range(total_reqs):
...@@ -672,7 +869,6 @@ class TestNixlHandshake: ...@@ -672,7 +869,6 @@ class TestNixlHandshake:
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
# mismatched layout is expected to fail # mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2) worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1) worker.add_remote_agent(meta, remote_tp_size=1)
@patch( @patch(
...@@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init): ...@@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
): ):
worker._recving_transfers = {"req1": [123]} worker._recving_transfers = {"req1": [123]}
worker.src_xfer_side_handle = 456 # Mock register_kv_cache which registers local handle
worker.dst_xfer_side_handles = {"engine1": 789} worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
# P TP = 2 * D TP case, we should register 2 local handles
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}} worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"] worker._registered_descs = ["desc1", "desc2"]
...@@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init): ...@@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener.join.assert_called_once() mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123) mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2 assert mock_rel_dlist.call_count == 4
mock_rel_dlist.assert_any_call(456) # src handle mock_rel_dlist.assert_any_call(455) # src handle (whole region)
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
mock_rel_dlist.assert_any_call(789) # dst handle mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1") mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2 assert mock_dereg.call_count == 2
......
...@@ -21,6 +21,8 @@ if TYPE_CHECKING: ...@@ -21,6 +21,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
EngineId = str
def get_kv_connector_cache_layout(): def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
...@@ -209,12 +211,12 @@ class TpKVTopology: ...@@ -209,12 +211,12 @@ class TpKVTopology:
""" """
tp_rank: int tp_rank: int
remote_tp_size: dict[str, int] remote_tp_size: dict[EngineId, int]
is_mla: bool is_mla: bool
total_num_kv_heads: int total_num_kv_heads: int
attn_backend: type[AttentionBackend] attn_backend: type[AttentionBackend]
engine_id: str engine_id: EngineId
remote_block_size: dict[str, int] remote_block_size: dict[EngineId, int]
def __post_init__(self): def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V # Figure out whether the first dimension of the cache is K/V
...@@ -256,18 +258,28 @@ class TpKVTopology: ...@@ -256,18 +258,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP. Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`. groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
""" """
assert self.tp_size % remote_tp_size == 0, ( if self.tp_size >= remote_tp_size:
f"Local tensor parallel size {self.tp_size} is not divisible " assert self.tp_size % remote_tp_size == 0, (
f"by remote tensor parallel size {remote_tp_size}." f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
) )
return self.tp_size // remote_tp_size # P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def block_size_ratio( def block_size_ratio(
self, self,
remote_block_size: int, remote_block_size: int,
) -> float: ) -> int:
""" """
Calculate the block size ratio between local and remote TP. Calculate the block size ratio between local and remote TP.
""" """
...@@ -279,19 +291,19 @@ class TpKVTopology: ...@@ -279,19 +291,19 @@ class TpKVTopology:
def tp_ratio_from_engine_id( def tp_ratio_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> int: ) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size) return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id( def block_size_ratio_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> float: ) -> int:
remote_block_size = self.remote_block_size[remote_engine_id] remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size) return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool: def is_kv_replicated(self, engine_id: EngineId) -> bool:
""" """
Whether the KV cache is replicated across TP workers due to the Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads. number of TP workers being greater than the number of KV heads.
...@@ -299,24 +311,30 @@ class TpKVTopology: ...@@ -299,24 +311,30 @@ class TpKVTopology:
tp_size = self.remote_tp_size[engine_id] tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1 return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool: def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split. # MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id) return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank( def get_target_remote_ranks(
self, self,
remote_tp_size: int, remote_tp_size: int,
) -> int: ) -> list[int]:
""" """
Get the remote TP rank (on P) that the current local TP rank Get the remote TP rank (on P) that the current local TP rank
(on D) will read from. (on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
""" """
tp_ratio = self.tp_ratio(remote_tp_size) tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
def get_target_remote_rank_from_engine_id( # P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
def get_target_remote_ranks_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> int: ) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size) return self.get_target_remote_ranks(remote_tp_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment