Unverified commit 8a4e5c5f, authored by Zhonghua Deng, committed by GitHub
Browse files

[V1][P/D]Enhance Performance and code readability for P2pNcclConnector (#20906)


Signed-off-by: Abatom <abzhonghua@gmail.com>
parent 76b49444
This diff is collapsed.
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
import os import os
import socket import socket
import threading import threading
import time
import uuid import uuid
from typing import Any
import aiohttp import aiohttp
import msgpack import msgpack
...@@ -12,12 +14,25 @@ import zmq ...@@ -12,12 +14,25 @@ import zmq
from quart import Quart, make_response, request from quart import Quart, make_response, request
count = 0 count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
decode_instances: dict[str, str] = {} # http_address: zmq_address decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
prefill_cv = threading.Condition() prefill_cv = threading.Condition()
decode_cv = threading.Condition() decode_cv = threading.Condition()
DEFAULT_PING_SECONDS = 5
def _remove_oldest_instances(instances: dict[str, Any]) -> None:
oldest_key = next(iter(instances), None)
while oldest_key is not None:
value = instances[oldest_key]
if value[1] > time.time():
break
print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]")
instances.pop(oldest_key, None)
oldest_key = next(iter(instances), None)
def _listen_for_register(poller, router_socket): def _listen_for_register(poller, router_socket):
while True: while True:
...@@ -31,12 +46,23 @@ def _listen_for_register(poller, router_socket): ...@@ -31,12 +46,23 @@ def _listen_for_register(poller, router_socket):
global prefill_instances global prefill_instances
global prefill_cv global prefill_cv
with prefill_cv: with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"] node = prefill_instances.pop(data["http_address"], None)
prefill_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(prefill_instances)
elif data["type"] == "D": elif data["type"] == "D":
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"] node = decode_instances.pop(data["http_address"], None)
decode_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(decode_instances)
else: else:
print( print(
"Unexpected, Received message from %s, data: %s", "Unexpected, Received message from %s, data: %s",
...@@ -44,6 +70,9 @@ def _listen_for_register(poller, router_socket): ...@@ -44,6 +70,9 @@ def _listen_for_register(poller, router_socket):
data, data,
) )
if node is None:
print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]")
def start_service_discovery(hostname, port): def start_service_discovery(hostname, port):
if not hostname: if not hostname:
...@@ -105,12 +134,14 @@ async def handle_request(): ...@@ -105,12 +134,14 @@ async def handle_request():
with prefill_cv: with prefill_cv:
prefill_list = list(prefill_instances.items()) prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)] prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
prefill_zmq_addr = prefill_zmq_addr[0]
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_list = list(decode_instances.items()) decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)] decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
decode_zmq_addr = decode_zmq_addr[0]
print( print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, " f"handle_request count: {count}, [HTTP:{prefill_addr}, "
......
...@@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ...@@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine) P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata) assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) self.p2p_nccl_engine.send_tensor(
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, request_id + "#" + layer_name, kv_layer, remote_address,
kv_cache, remote_address) request.slot_mapping,
isinstance(attn_metadata, MLACommonMetadata))
def wait_for_save(self): def wait_for_save(self):
if self.is_producer: if self.is_producer:
...@@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
forward_context: ForwardContext = get_forward_context() no_compile_layers = (
self._vllm_config.compilation_config.static_forward_context)
return self.p2p_nccl_engine.get_finished(finished_req_ids, return self.p2p_nccl_engine.get_finished(finished_req_ids,
forward_context) no_compile_layers)
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
...@@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_ids=block_ids, block_ids=block_ids,
block_size=self._block_size) block_size=self._block_size)
# Requests loaded asynchronously are not in the scheduler_output.
# for request_id in self._requests_need_load:
# request, block_ids = self._requests_need_load[request_id]
# meta.add_request(request_id=request.request_id,
# token_ids=request.prompt_token_ids,
# block_ids=block_ids,
# block_size=self._block_size)
self._requests_need_load.clear() self._requests_need_load.clear()
return meta return meta
......
...@@ -8,7 +8,8 @@ import time ...@@ -8,7 +8,8 @@ import time
import typing import typing
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional from dataclasses import dataclass
from typing import Any, Optional
import msgpack import msgpack
import torch import torch
...@@ -21,9 +22,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ...@@ -21,9 +22,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
TensorMemoryPool) TensorMemoryPool)
from vllm.utils import current_stream, get_ip from vllm.utils import current_stream, get_ip
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32 DEFAULT_MEM_POOL_SIZE_GB = 32
...@@ -59,6 +57,15 @@ def set_p2p_nccl_context(num_channels: str): ...@@ -59,6 +57,15 @@ def set_p2p_nccl_context(num_channels: str):
os.environ.pop(var, None) os.environ.pop(var, None)
@dataclass
class SendQueueItem:
    """One pending tensor transfer handed to the engine's send path.

    Bundles everything the sender needs so the KV extraction can be done
    at send time rather than by the caller.
    """

    # Unique key of the tensor on both sides of the transfer.
    tensor_id: str
    # Address of the peer the tensor is sent to.
    remote_address: str
    # Source tensor from which the KV entries are extracted before sending.
    tensor: torch.Tensor
    # Index used to select entries from ``tensor``; may be None because
    # ``send_tensor`` forwards its ``None`` default unchanged.
    slot_mapping: Optional[torch.Tensor]
    # Selects the MLA layout when extracting from ``tensor``.
    is_mla: bool
class P2pNcclEngine: class P2pNcclEngine:
def __init__(self, def __init__(self,
...@@ -112,24 +119,26 @@ class P2pNcclEngine: ...@@ -112,24 +119,26 @@ class P2pNcclEngine:
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
mem_pool_size_gb = self.config.get_from_extra_config( mem_pool_size_gb = float(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB) self.config.get_from_extra_config("mem_pool_size_gb",
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) * DEFAULT_MEM_POOL_SIZE_GB))
1024**3) # GB self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
1024**3)) # GB
# The sending type includes three mutually exclusive options: # The sending type includes three mutually exclusive options:
# PUT, GET, PUT_ASYNC. # PUT, GET, PUT_ASYNC.
self.send_type = self.config.get_from_extra_config("send_type", "PUT") self.send_type = self.config.get_from_extra_config(
"send_type", "PUT_ASYNC")
if self.send_type == "GET": if self.send_type == "GET":
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_store: dict[str, torch.Tensor] = {} self.send_store: dict[str, torch.Tensor] = {}
else: else:
# PUT or PUT_ASYNC # PUT or PUT_ASYNC
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque() self.send_queue: deque[SendQueueItem] = deque()
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async, self._send_thread = threading.Thread(target=self.send_async,
daemon=True) daemon=True)
self._send_thread.start() self._send_thread.start()
...@@ -146,13 +155,12 @@ class P2pNcclEngine: ...@@ -146,13 +155,12 @@ class P2pNcclEngine:
"nccl_num_channels", "8") "nccl_num_channels", "8")
self._listener_thread = threading.Thread( self._listener_thread = threading.Thread(
target=self._listen_for_requests, daemon=True) target=self.listen_for_requests, daemon=True)
self._listener_thread.start() self._listener_thread.start()
self._ping_thread = None self._ping_thread = None
if port_offset == 0 and self.proxy_address != "": if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping, self._ping_thread = threading.Thread(target=self.ping, daemon=True)
daemon=True)
self._ping_thread.start() self._ping_thread.start()
logger.info( logger.info(
...@@ -162,7 +170,7 @@ class P2pNcclEngine: ...@@ -162,7 +170,7 @@ class P2pNcclEngine:
self.http_address, self.zmq_address, self.proxy_address, self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels) self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def _create_connect(self, remote_address: typing.Optional[str] = None): def create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None assert remote_address is not None
if remote_address not in self.socks: if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER) sock = self.context.socket(zmq.DEALER)
...@@ -184,7 +192,7 @@ class P2pNcclEngine: ...@@ -184,7 +192,7 @@ class P2pNcclEngine:
comm: ncclComm_t = self.nccl.ncclCommInitRank( comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank) 2, unique_id, rank)
self.comms[remote_address] = (comm, rank) self.comms[remote_address] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
self.zmq_address, remote_address, rank) self.zmq_address, remote_address, rank)
return self.socks[remote_address], self.comms[remote_address] return self.socks[remote_address], self.comms[remote_address]
...@@ -194,20 +202,31 @@ class P2pNcclEngine: ...@@ -194,20 +202,31 @@ class P2pNcclEngine:
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[str] = None,
slot_mapping: torch.Tensor = None,
is_mla: bool = False,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
with self.recv_store_cv: with self.recv_store_cv:
self.recv_store[tensor_id] = tensor self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify() self.recv_store_cv.notify()
return True return True
else:
item = SendQueueItem(tensor_id=tensor_id,
remote_address=remote_address,
tensor=tensor,
slot_mapping=slot_mapping,
is_mla=is_mla)
if self.send_type == "PUT": if self.send_type == "PUT":
return self._send_sync(tensor_id, tensor, remote_address) return self.send_sync(item)
elif self.send_type == "PUT_ASYNC":
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv: with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor]) self.send_queue.append(item)
self.send_queue_cv.notify() self.send_queue_cv.notify()
else: # GET return True
# GET
with self.send_store_cv: with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel() tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size while (self.buffer_size + tensor_size
...@@ -220,18 +239,17 @@ class P2pNcclEngine: ...@@ -220,18 +239,17 @@ class P2pNcclEngine:
logger.info( logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d", " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, remote_address, tensor_id, tensor_size, self.buffer_size,
self.buffer_size, oldest_tenser_size, self.rank) oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size self.buffer_size += tensor_size
logger.debug( logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
remote_address, tensor_id, tensor_size, tensor.shape, tensor_id, tensor_size, tensor.shape, self.rank,
self.rank, self.buffer_size, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100) self.buffer_size / self.buffer_size_threshold * 100)
return True return True
def recv_tensor( def recv_tensor(
...@@ -267,7 +285,7 @@ class P2pNcclEngine: ...@@ -267,7 +285,7 @@ class P2pNcclEngine:
return None return None
if remote_address not in self.socks: if remote_address not in self.socks:
self._create_connect(remote_address) self.create_connect(remote_address)
sock = self.socks[remote_address] sock = self.socks[remote_address]
comm, rank = self.comms[remote_address] comm, rank = self.comms[remote_address]
...@@ -282,18 +300,21 @@ class P2pNcclEngine: ...@@ -282,18 +300,21 @@ class P2pNcclEngine:
remote_address, tensor_id, data["ret"]) remote_address, tensor_id, data["ret"])
return None return None
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"], tensor = torch.empty(data["shape"],
dtype=getattr(torch, data["dtype"]), dtype=getattr(torch, data["dtype"]),
device=self.device) device=self.device)
self._recv(comm, tensor, rank ^ 1, self.recv_stream) self.recv(comm, tensor, rank ^ 1, self.recv_stream)
return tensor return tensor
def _listen_for_requests(self): def listen_for_requests(self):
while True: while True:
socks = dict(self.poller.poll()) socks = dict(self.poller.poll())
if self.router_socket in socks: if self.router_socket not in socks:
continue
remote_address, message = self.router_socket.recv_multipart() remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message) data = msgpack.loads(message)
if data["cmd"] == "NEW": if data["cmd"] == "NEW":
...@@ -305,9 +326,9 @@ class P2pNcclEngine: ...@@ -305,9 +326,9 @@ class P2pNcclEngine:
comm: ncclComm_t = self.nccl.ncclCommInitRank( comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank) 2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank) self.comms[remote_address.decode()] = (comm, rank)
logger.info( logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", self.zmq_address, remote_address.decode(),
self.zmq_address, remote_address.decode(), rank) rank)
elif data["cmd"] == "PUT": elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
try: try:
...@@ -316,10 +337,9 @@ class P2pNcclEngine: ...@@ -316,10 +337,9 @@ class P2pNcclEngine:
dtype=getattr( dtype=getattr(
torch, data["dtype"]), torch, data["dtype"]),
device=self.device) device=self.device)
self.router_socket.send_multipart( self.router_socket.send_multipart([remote_address, b"0"])
[remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()] comm, rank = self.comms[remote_address.decode()]
self._recv(comm, tensor, rank ^ 1, self.recv_stream) self.recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel() tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size if (self.buffer_size + tensor_size
> self.buffer_size_threshold): > self.buffer_size_threshold):
...@@ -334,17 +354,16 @@ class P2pNcclEngine: ...@@ -334,17 +354,16 @@ class P2pNcclEngine:
self.buffer_size += tensor_size self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError: except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart( self.router_socket.send_multipart([remote_address, b"1"])
[remote_address, b"1"])
tensor = None tensor = None
logger.warning( logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address, "data:%s", self.zmq_address, remote_address.decode(),
remote_address.decode(), data) data)
with self.recv_store_cv: with self.recv_store_cv:
self.recv_store[tensor_id] = tensor self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id) self.have_received_tensor_id(tensor_id)
self.recv_store_cv.notify() self.recv_store_cv.notify()
elif data["cmd"] == "GET": elif data["cmd"] == "GET":
...@@ -355,12 +374,11 @@ class P2pNcclEngine: ...@@ -355,12 +374,11 @@ class P2pNcclEngine:
data = { data = {
"ret": 0, "ret": 0,
"shape": tensor.shape, "shape": tensor.shape,
"dtype": "dtype": str(tensor.dtype).replace("torch.", "")
str(tensor.dtype).replace("torch.", "")
} }
# LRU # LRU
self.send_store[tensor_id] = tensor self.send_store[tensor_id] = tensor
self._have_sent_tensor_id(tensor_id) self.have_sent_tensor_id(tensor_id)
else: else:
data = {"ret": 1} data = {"ret": 1}
...@@ -369,34 +387,34 @@ class P2pNcclEngine: ...@@ -369,34 +387,34 @@ class P2pNcclEngine:
if data["ret"] == 0: if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()] comm, rank = self.comms[remote_address.decode()]
self._send(comm, tensor.to(self.device), rank ^ 1, self.send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream) self.send_stream)
else: else:
logger.warning( logger.warning(
"🚧Unexpected, Received message from %s, data:%s", "🚧Unexpected, Received message from %s, data:%s",
remote_address, data) remote_address, data)
def _have_sent_tensor_id(self, tensor_id: str): def have_sent_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0] request_id = tensor_id.split('#')[0]
if request_id not in self.send_request_id_to_tensor_ids: if request_id not in self.send_request_id_to_tensor_ids:
self.send_request_id_to_tensor_ids[request_id] = set() self.send_request_id_to_tensor_ids[request_id] = set()
self.send_request_id_to_tensor_ids[request_id].add(tensor_id) self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
def _have_received_tensor_id(self, tensor_id: str): def have_received_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0] request_id = tensor_id.split('#')[0]
if request_id not in self.recv_request_id_to_tensor_ids: if request_id not in self.recv_request_id_to_tensor_ids:
self.recv_request_id_to_tensor_ids[request_id] = set() self.recv_request_id_to_tensor_ids[request_id] = set()
self.recv_request_id_to_tensor_ids[request_id].add(tensor_id) self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
def _send_async(self): def send_async(self):
while True: while True:
with self.send_queue_cv: with self.send_queue_cv:
while not self.send_queue: while not self.send_queue:
self.send_queue_cv.wait() self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft() item = self.send_queue.popleft()
if not self.send_queue: if not self.send_queue:
self.send_queue_cv.notify() self.send_queue_cv.notify()
self._send_sync(tensor_id, tensor, remote_address) self.send_sync(item)
def wait_for_sent(self): def wait_for_sent(self):
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
...@@ -409,22 +427,21 @@ class P2pNcclEngine: ...@@ -409,22 +427,21 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank) " to be empty, rank:%d", duration * 1000, self.rank)
def _send_sync( def send_sync(self, item: SendQueueItem) -> bool:
self, if item.remote_address is None:
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
) -> bool:
if remote_address is None:
return False return False
if remote_address not in self.socks: if item.remote_address not in self.socks:
self._create_connect(remote_address) self.create_connect(item.remote_address)
sock = self.socks[remote_address] with self.send_stream:
comm, rank = self.comms[remote_address] tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
item.slot_mapping)
sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address]
data = { data = {
"cmd": "PUT", "cmd": "PUT",
"tensor_id": tensor_id, "tensor_id": item.tensor_id,
"shape": tensor.shape, "shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "") "dtype": str(tensor.dtype).replace("torch.", "")
} }
...@@ -435,20 +452,21 @@ class P2pNcclEngine: ...@@ -435,20 +452,21 @@ class P2pNcclEngine:
logger.error( logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, remote_address, rank, data, tensor.shape, self.zmq_address, item.remote_address, rank, data,
tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3, tensor.element_size() * tensor.numel() / 1024**3,
response.decode()) response.decode())
return False return False
self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id) self.have_sent_tensor_id(item.tensor_id)
return True return True
def get_finished( def get_finished(
self, finished_req_ids: set[str], forward_context: "ForwardContext" self, finished_req_ids: set[str], no_compile_layers
) -> tuple[Optional[set[str]], Optional[set[str]]]: ) -> tuple[Optional[set[str]], Optional[set[str]]]:
""" """
Notifies worker-side connector ids of requests that have Notifies worker-side connector ids of requests that have
...@@ -463,7 +481,7 @@ class P2pNcclEngine: ...@@ -463,7 +481,7 @@ class P2pNcclEngine:
# Clear the buffer upon request completion. # Clear the buffer upon request completion.
for request_id in finished_req_ids: for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers: for layer_name in no_compile_layers:
tensor_id = request_id + "#" + layer_name tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store: if tensor_id in self.recv_store:
with self.recv_store_cv: with self.recv_store_cv:
...@@ -472,7 +490,6 @@ class P2pNcclEngine: ...@@ -472,7 +490,6 @@ class P2pNcclEngine:
request_id, None) request_id, None)
self.recv_request_id_to_tensor_ids.pop( self.recv_request_id_to_tensor_ids.pop(
request_id, None) request_id, None)
addr = 0
if isinstance(tensor, tuple): if isinstance(tensor, tuple):
addr, _, _ = tensor addr, _, _ = tensor
self.pool.free(addr) self.pool.free(addr)
...@@ -485,7 +502,7 @@ class P2pNcclEngine: ...@@ -485,7 +502,7 @@ class P2pNcclEngine:
return finished_sending or None, finished_recving or None return finished_sending or None, finished_recving or None
def _ping(self): def ping(self):
sock = self.context.socket(zmq.DEALER) sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address) logger.debug("ping start, zmq_address:%s", self.zmq_address)
...@@ -499,7 +516,7 @@ class P2pNcclEngine: ...@@ -499,7 +516,7 @@ class P2pNcclEngine:
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
time.sleep(3) time.sleep(3)
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, ( assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
...@@ -512,7 +529,7 @@ class P2pNcclEngine: ...@@ -512,7 +529,7 @@ class P2pNcclEngine:
comm, cudaStream_t(stream.cuda_stream)) comm, cudaStream_t(stream.cuda_stream))
stream.synchronize() stream.synchronize()
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
assert tensor.device == self.device, ( assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
...@@ -531,3 +548,21 @@ class P2pNcclEngine: ...@@ -531,3 +548,21 @@ class P2pNcclEngine:
self._send_thread.join() self._send_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
@staticmethod
def extract_kv_from_layer(
    is_mla: bool,
    layer: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> torch.Tensor:
    """Gather the KV-cache rows addressed by ``slot_mapping``.

    The layer is expected to be laid out as
    ``(2, num_pages, page_size, ...)`` without MLA, or as
    ``(num_pages, page_size, ...)`` with MLA. The page and in-page axes are
    merged into a single slot axis, all trailing axes are collapsed into
    one, and the slots listed in ``slot_mapping`` are selected.

    Args:
        is_mla: Whether ``layer`` uses the MLA layout.
        layer: The KV-cache tensor of one layer.
        slot_mapping: Indices into the merged slot axis.

    Returns:
        The selected slots: two-dimensional for MLA, three-dimensional
        (leading axis of size 2) otherwise.
    """
    if is_mla:
        page_count = layer.shape[0]
        slots_per_page = layer.shape[1]
        flat = layer.reshape(page_count * slots_per_page, -1)
        return flat[slot_mapping, ...]

    page_count = layer.shape[1]
    slots_per_page = layer.shape[2]
    # Keep the leading axis of size 2 intact; only slots are selected.
    flat = layer.reshape(2, page_count * slots_per_page, -1)
    return flat[:, slot_mapping, ...]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment