Commit 7cbb7097 authored by maxiao1
Browse files

修复数据集推理时候decode侧卡住 (Fix the decode side hanging during dataset inference)

parent c2e6f453
......@@ -252,7 +252,13 @@ class P2pNcclConnector(KVConnectorBase_V1):
2, num_pages * page_size, -1)
inject_start_index = 0
for num in range(self.p2p_nccl_engine.tensor_split_num):
req_layer = f"{request.request_id}#{layer_name}"
with self.p2p_nccl_engine.recv_store_cv:
while req_layer not in self.p2p_nccl_engine.recv_split_nums:
self.p2p_nccl_engine.recv_store_cv.wait()
split_num = self.p2p_nccl_engine.recv_split_nums.get(req_layer)
for num in range(split_num):
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num))
......@@ -280,6 +286,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
......
......@@ -117,6 +117,7 @@ class P2pNcclEngine:
self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS
self.p2p_async_buf = None
self.tensor_split_num: int = 0
self.recv_split_nums: dict[str, int] = {}
mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
......@@ -200,7 +201,6 @@ class P2pNcclEngine:
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
......@@ -251,7 +251,8 @@ class P2pNcclEngine:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
# self.recv_store_cv.notify()
self.recv_store_cv.notify_all()
return True
else:
if self.send_type == "PUT":
......@@ -260,7 +261,7 @@ class P2pNcclEngine:
with self.send_queue_cv:
kv_layer, slot_mapping = tensor # tensor: (kv_layer, slot_mapping)
self.send_queue.append([tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt])
self.send_queue_cv.notify()
self.send_queue_cv.notify_all()
else: # GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
......@@ -365,7 +366,14 @@ class P2pNcclEngine:
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"]
# self.tensor_split_num = data["tensor_split_num"]
parts = tensor_id.split('#')
request_id = parts[0]
layer_name = parts[1]
req_layer = f"{request_id}#{layer_name}"
self.recv_split_nums[req_layer] = data["tensor_split_num"]
with self.recv_store_cv:
self.recv_store_cv.notify_all()
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
......@@ -397,7 +405,8 @@ class P2pNcclEngine:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
#self.recv_store_cv.notify()
self.recv_store_cv.notify_all()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
......@@ -450,7 +459,7 @@ class P2pNcclEngine:
else:
tensor_id, remote_address, tensor = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
self.send_queue_cv.notify_all()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
......@@ -590,20 +599,30 @@ class P2pNcclEngine:
"""
# Clear the buffer upon request completion.
# for request_id in finished_req_ids:
# for layer_name in forward_context.no_compile_layers:
# tensor_id = request_id + "#" + layer_name
# if tensor_id in self.recv_store:
# with self.recv_store_cv:
# tensor = self.recv_store.pop(tensor_id, None)
# self.send_request_id_to_tensor_ids.pop(
# request_id, None)
# self.recv_request_id_to_tensor_ids.pop(
# request_id, None)
# addr = 0
# if isinstance(tensor, tuple):
# addr, _, _ = tensor
# self.pool.free(addr)
for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers:
tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store:
with self.recv_store_cv:
tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
addr = 0
ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
with self.recv_store_cv:
for tensor_id in ids:
tensor = self.recv_store.pop(tensor_id, None)
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.pool.free(addr)
self.send_request_id_to_tensor_ids.pop(request_id, None)
# TODO: Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment