"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "e740d07f07d82983217077b89e23beaae134a30b"
Unverified Commit 45cbc499 authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

[Bugfix] Fix disagg hang caused by the prefill and decode communication issues (#12723)


Signed-off-by: default avatarLu Fang <lufang@fb.com>
parent 932c6b74
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
stop the prefill instance when the decode instance is slow. stop the prefill instance when the decode instance is slow.
""" """
import threading import threading
import time
from collections import deque from collections import deque
from typing import Deque, List, Optional, Union from typing import Deque, List, Optional, Union
...@@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float): buffer_size_thresh: float):
""" """
signal_pipe: on CPU signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request. CPU recv to listen to new request.
data_pipe: on device (e.g. GPU) data_pipe: on device (e.g. GPU)
""" """
...@@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase):
self.buffer_size = 0 self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh self.buffer_size_threshold = buffer_size_thresh
self.buffer_lock = threading.Lock() self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe self.signal_pipe = signal_pipe
self.data_pipe = data_pipe self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None self.request_handling_thread: Optional[threading.Thread] = None
...@@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase):
hidden = hidden.clone() hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden] buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])
with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()
with self.buffer_lock: self.buffer_size += data_size
for data in buffer_item:
self.buffer_size += self._get_element_size(data)
self.buffer.append(buffer_item) self.buffer.append(buffer_item)
self.buffer_cv.notify()
def _is_end_signal(self, signal): def _is_end_signal(self, signal):
return signal is None return signal is None
...@@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase):
roi = (roi > 0.5) roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi] tokens_roi_recver = [input_tokens, roi]
matched_length = 0 def is_buffer_available(
tokens_roi_recver: List[torch.Tensor], ) -> bool:
# perform input tokens and roi matching # perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1) # FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so # but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent. # the fix is not urgent.
with self.buffer_lock:
for _ in range(len(self.buffer)): for _ in range(len(self.buffer)):
if self._matches(self.buffer[0],
temp_length = self._matches(self.buffer[0], tokens_roi_recver) > 0:
tokens_roi_recver) return True
if temp_length > 0:
matched_length = temp_length
break
# rotate the element we just accessed to the end # rotate the element we just accessed to the end
self.buffer.rotate(-1) self.buffer.rotate(-1)
return False
if matched_length > 0:
# need to clone the tensor with self.buffer_cv:
# in case the tensor is freed before sending finishes while not is_buffer_available(tokens_roi_recver):
matched_item = self.buffer.popleft() logger.debug(
for tensor in matched_item: "KV transfer buffer is not available. Waiting...")
self._send_tensor_and_dec_size(tensor) self.buffer_cv.wait()
# need to clone the tensor
else: # in case the tensor is freed before sending finishes
# no match, just send None matched_item = self.buffer.popleft()
for _ in range(5): for tensor in matched_item:
self.data_pipe.send_tensor(None) self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()
except RuntimeError as e: except RuntimeError as e:
if 'Connection closed by peer' not in str(e): if 'Connection closed by peer' not in str(e):
...@@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase):
return [input_tokens, roi, key, value, hidden] return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None: hidden: torch.Tensor) -> None:
if self.buffer_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
self._add_to_buffer(input_tokens, roi, key, value, hidden) self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender # when calling the insert, the current process is a sender
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment