Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45cbc499
Unverified
Commit
45cbc499
authored
Feb 07, 2025
by
Lu Fang
Committed by
GitHub
Feb 07, 2025
Browse files
[Bugfix] Fix disagg hang caused by the prefill and decode communication issues (#12723)
Signed-off-by:
Lu Fang
<
lufang@fb.com
>
parent
932c6b74
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
47 deletions
+40
-47
vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
...distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
+40
-47
No files found.
vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
View file @
45cbc499
...
...
@@ -10,7 +10,6 @@
stop the prefill instance when the decode instance is slow.
"""
import
threading
import
time
from
collections
import
deque
from
typing
import
Deque
,
List
,
Optional
,
Union
...
...
@@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase):
self
.
buffer_size
=
0
self
.
buffer_size_threshold
=
buffer_size_thresh
self
.
buffer_
lock
=
threading
.
Lock
()
self
.
buffer_
cv
=
threading
.
Condition
()
self
.
signal_pipe
=
signal_pipe
self
.
data_pipe
=
data_pipe
self
.
request_handling_thread
:
Optional
[
threading
.
Thread
]
=
None
...
...
@@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase):
hidden
=
hidden
.
clone
()
buffer_item
=
[
input_tokens
,
roi
,
key
,
value
,
hidden
]
data_size
=
sum
([
self
.
_get_element_size
(
data
)
for
data
in
buffer_item
])
with
self
.
buffer_lock
:
for
data
in
buffer_item
:
self
.
buffer_size
+=
self
.
_get_element_size
(
data
)
with
self
.
buffer_cv
:
if
self
.
buffer_size
+
data_size
>
self
.
buffer_size_threshold
:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger
.
debug
(
"KV transfer buffer is full. Handling..."
)
while
self
.
buffer_size
+
data_size
>
self
.
buffer_size_threshold
:
self
.
buffer_cv
.
wait
()
self
.
buffer_size
+=
data_size
self
.
buffer
.
append
(
buffer_item
)
self
.
buffer_cv
.
notify
()
def
_is_end_signal
(
self
,
signal
):
return
signal
is
None
...
...
@@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase):
roi
=
(
roi
>
0.5
)
tokens_roi_recver
=
[
input_tokens
,
roi
]
matched_length
=
0
def
is_buffer_available
(
tokens_roi_recver
:
List
[
torch
.
Tensor
],
)
->
bool
:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
with
self
.
buffer_lock
:
for
_
in
range
(
len
(
self
.
buffer
)):
temp_length
=
self
.
_matches
(
self
.
buffer
[
0
],
tokens_roi_recver
)
if
temp_length
>
0
:
matched_length
=
temp_length
break
if
self
.
_matches
(
self
.
buffer
[
0
],
tokens_roi_recver
)
>
0
:
return
True
# rotate the element we just accessed to the end
self
.
buffer
.
rotate
(
-
1
)
return
False
if
matched_length
>
0
:
with
self
.
buffer_cv
:
while
not
is_buffer_available
(
tokens_roi_recver
):
logger
.
debug
(
"KV transfer buffer is not available. Waiting..."
)
self
.
buffer_cv
.
wait
()
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item
=
self
.
buffer
.
popleft
()
for
tensor
in
matched_item
:
self
.
_send_tensor_and_dec_size
(
tensor
)
else
:
# no match, just send None
for
_
in
range
(
5
):
self
.
data_pipe
.
send_tensor
(
None
)
self
.
buffer_cv
.
notify
()
except
RuntimeError
as
e
:
if
'Connection closed by peer'
not
in
str
(
e
):
...
...
@@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase):
return
[
input_tokens
,
roi
,
key
,
value
,
hidden
]
def
full_handler
(
self
):
time
.
sleep
(
0.001
)
def
insert
(
self
,
input_tokens
:
torch
.
Tensor
,
roi
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
hidden
:
torch
.
Tensor
)
->
None
:
if
self
.
buffer_size
>
self
.
buffer_size_threshold
:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger
.
debug
(
"KV transfer buffer is full. Handling..."
)
while
self
.
buffer_size
>
self
.
buffer_size_threshold
:
self
.
full_handler
()
self
.
_add_to_buffer
(
input_tokens
,
roi
,
key
,
value
,
hidden
)
# when calling the insert, the current process is a sender
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment