Commit 8a4e5c5f (unverified), authored by Zhonghua Deng, committed by GitHub
Browse files

[V1][P/D]Enhance Performance and code readability for P2pNcclConnector (#20906)


Signed-off-by: Abatom <abzhonghua@gmail.com>
parent 76b49444
This diff is collapsed.
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
import os import os
import socket import socket
import threading import threading
import time
import uuid import uuid
from typing import Any
import aiohttp import aiohttp
import msgpack import msgpack
...@@ -12,12 +14,25 @@ import zmq ...@@ -12,12 +14,25 @@ import zmq
from quart import Quart, make_response, request from quart import Quart, make_response, request
count = 0 count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
decode_instances: dict[str, str] = {} # http_address: zmq_address decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
prefill_cv = threading.Condition() prefill_cv = threading.Condition()
decode_cv = threading.Condition() decode_cv = threading.Condition()
DEFAULT_PING_SECONDS = 5
def _remove_oldest_instances(instances: dict[str, Any]) -> None:
    """Drop expired entries from the front of ``instances``, in place.

    Entries are visited in the dict's iteration (insertion) order. Each
    value is indexed as ``value[0]`` (printed as the ZMQ address) and
    ``value[1]`` (a stamp compared against ``time.time()``).

    Eviction stops at the first entry whose stamp is still later than the
    current time; entries behind it are not inspected, even if their own
    stamps have already passed.

    Args:
        instances: Mapping of HTTP address to an indexable value as
            described above. Mutated in place.
    """
    while True:
        head = next(iter(instances), None)
        if head is None:
            # Nothing left to examine.
            return
        entry = instances[head]
        if entry[1] > time.time():
            # First live entry reached: everything after it is kept.
            return
        print(f"🔴Remove [HTTP:{head}, ZMQ:{entry[0]}, stamp:{entry[1]}]")
        instances.pop(head, None)
def _listen_for_register(poller, router_socket): def _listen_for_register(poller, router_socket):
while True: while True:
...@@ -31,12 +46,23 @@ def _listen_for_register(poller, router_socket): ...@@ -31,12 +46,23 @@ def _listen_for_register(poller, router_socket):
global prefill_instances global prefill_instances
global prefill_cv global prefill_cv
with prefill_cv: with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"] node = prefill_instances.pop(data["http_address"], None)
prefill_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(prefill_instances)
elif data["type"] == "D": elif data["type"] == "D":
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"] node = decode_instances.pop(data["http_address"], None)
decode_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(decode_instances)
else: else:
print( print(
"Unexpected, Received message from %s, data: %s", "Unexpected, Received message from %s, data: %s",
...@@ -44,6 +70,9 @@ def _listen_for_register(poller, router_socket): ...@@ -44,6 +70,9 @@ def _listen_for_register(poller, router_socket):
data, data,
) )
if node is None:
print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]")
def start_service_discovery(hostname, port): def start_service_discovery(hostname, port):
if not hostname: if not hostname:
...@@ -105,12 +134,14 @@ async def handle_request(): ...@@ -105,12 +134,14 @@ async def handle_request():
with prefill_cv: with prefill_cv:
prefill_list = list(prefill_instances.items()) prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)] prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
prefill_zmq_addr = prefill_zmq_addr[0]
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_list = list(decode_instances.items()) decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)] decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
decode_zmq_addr = decode_zmq_addr[0]
print( print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, " f"handle_request count: {count}, [HTTP:{prefill_addr}, "
......
...@@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ...@@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine) P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata) assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) self.p2p_nccl_engine.send_tensor(
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, request_id + "#" + layer_name, kv_layer, remote_address,
kv_cache, remote_address) request.slot_mapping,
isinstance(attn_metadata, MLACommonMetadata))
def wait_for_save(self): def wait_for_save(self):
if self.is_producer: if self.is_producer:
...@@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
forward_context: ForwardContext = get_forward_context() no_compile_layers = (
self._vllm_config.compilation_config.static_forward_context)
return self.p2p_nccl_engine.get_finished(finished_req_ids, return self.p2p_nccl_engine.get_finished(finished_req_ids,
forward_context) no_compile_layers)
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
...@@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_ids=block_ids, block_ids=block_ids,
block_size=self._block_size) block_size=self._block_size)
# Requests loaded asynchronously are not in the scheduler_output.
# for request_id in self._requests_need_load:
# request, block_ids = self._requests_need_load[request_id]
# meta.add_request(request_id=request.request_id,
# token_ids=request.prompt_token_ids,
# block_ids=block_ids,
# block_size=self._block_size)
self._requests_need_load.clear() self._requests_need_load.clear()
return meta return meta
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment