Unverified Commit 98b09ddc authored by Andy Lo's avatar Andy Lo Committed by GitHub
Browse files

[NIXL][Bugfix] metrics & testing minor bug (#36051)


Signed-off-by: default avatarAndy Lo <andy@mistral.ai>
parent cef1f302
...@@ -694,16 +694,18 @@ class TestNixlHandshake: ...@@ -694,16 +694,18 @@ class TestNixlHandshake:
) )
@pytest.mark.parametrize("local_tp_size", [1, 2]) @pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size( def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, default_vllm_config, dist_init self, local_tp_size: int, default_vllm_config, dist_init, monkeypatch
): ):
""" """
Verify remote TP > local TP handshake succeeds with different Verify remote TP > local TP handshake succeeds with different
remote configurations. remote configurations.
""" """
monkeypatch.setattr(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",
lambda: local_tp_size,
)
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector( connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
...@@ -738,10 +740,10 @@ class TestNixlHandshake: ...@@ -738,10 +740,10 @@ class TestNixlHandshake:
remote_agents = worker._nixl_handshake( remote_agents = worker._nixl_handshake(
host="localhost", host="localhost",
port=1234, port=1234,
remote_tp_size=2, remote_tp_size=4,
expected_engine_id=worker.REMOTE_ENGINE_ID, expected_engine_id=worker.REMOTE_ENGINE_ID,
) )
check_handshake(2) check_handshake(4)
# NOTE flexibility: a second remote with higher number of ranks is # NOTE flexibility: a second remote with higher number of ranks is
# discovered. This is not a scenario we actively support right now, but # discovered. This is not a scenario we actively support right now, but
...@@ -759,9 +761,8 @@ class TestNixlHandshake: ...@@ -759,9 +761,8 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla( def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, default_vllm_config, dist_init self, default_vllm_config, dist_init
): ):
""" """
Verify remote TP > local TP handshake succeeds with different Verify remote TP > local TP handshake succeeds with different
......
...@@ -1318,12 +1318,12 @@ class NixlConnectorWorker: ...@@ -1318,12 +1318,12 @@ class NixlConnectorWorker:
f"Expected {expected_engine_id}," f"Expected {expected_engine_id},"
f"received {metadata.engine_id}." f"received {metadata.engine_id}."
) )
setup_agent_time = time.perf_counter()
# Register Remote agent. # Register Remote agent.
remote_agent_name = self.add_remote_agent( remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size metadata, remote_rank, remote_tp_size
) )
setup_agent_time = time.perf_counter()
logger.debug( logger.debug(
"NIXL handshake: add agent took: %s", "NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time, setup_agent_time - got_metadata_time,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment