Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
3bde1010
Unverified
Commit
3bde1010
authored
May 21, 2025
by
Byron Hsu
Committed by
GitHub
May 21, 2025
Browse files
[PD] Abort request if transfer fails (#6504)
parent
75135580
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
84 additions
and
4 deletions
+84
-4
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+30
-2
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+2
-0
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+24
-2
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+15
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+13
-0
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
3bde1010
...
@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.utils import (
...
@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.utils import (
is_mla_backend
,
is_mla_backend
,
kv_to_page_indices
,
kv_to_page_indices
,
poll_and_all_reduce
,
poll_and_all_reduce
,
prepare_abort
,
)
)
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
...
@@ -178,7 +179,17 @@ class DecodePreallocQueue:
...
@@ -178,7 +179,17 @@ class DecodePreallocQueue:
elif
poll
==
KVPoll
.
WaitingForInput
:
elif
poll
==
KVPoll
.
WaitingForInput
:
decode_req
.
waiting_for_input
=
True
decode_req
.
waiting_for_input
=
True
elif
poll
==
KVPoll
.
Failed
:
elif
poll
==
KVPoll
.
Failed
:
raise
Exception
(
"Handshake failed"
)
error_message
=
f
"Decode handshake failed for request rank=
{
self
.
tp_rank
}
{
decode_req
.
req
.
rid
=
}
{
decode_req
.
req
.
bootstrap_room
=
}
"
try
:
decode_req
.
kv_receiver
.
failure_exception
()
except
Exception
as
e
:
error_message
+=
f
" with exception
{
e
}
"
logger
.
error
(
error_message
)
prepare_abort
(
decode_req
.
req
,
error_message
,
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
,
)
def
pop_preallocated
(
self
)
->
List
[
DecodeRequest
]:
def
pop_preallocated
(
self
)
->
List
[
DecodeRequest
]:
"""Pop the preallocated requests from the pending queue (FIFO)."""
"""Pop the preallocated requests from the pending queue (FIFO)."""
...
@@ -333,7 +344,24 @@ class DecodeTransferQueue:
...
@@ -333,7 +344,24 @@ class DecodeTransferQueue:
indices_to_remove
=
set
()
indices_to_remove
=
set
()
for
i
,
(
decode_req
,
poll
)
in
enumerate
(
zip
(
self
.
queue
,
polls
)):
for
i
,
(
decode_req
,
poll
)
in
enumerate
(
zip
(
self
.
queue
,
polls
)):
if
poll
==
KVPoll
.
Failed
:
if
poll
==
KVPoll
.
Failed
:
raise
Exception
(
"Transfer failed"
)
error_message
=
f
"Decode transfer failed for request rank=
{
self
.
tp_rank
}
{
decode_req
.
req
.
rid
=
}
{
decode_req
.
req
.
bootstrap_room
=
}
"
try
:
decode_req
.
kv_receiver
.
failure_exception
()
except
Exception
as
e
:
error_message
+=
f
" with exception
{
e
}
"
logger
.
error
(
error_message
)
prepare_abort
(
decode_req
.
req
,
error_message
,
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
,
)
self
.
scheduler
.
stream_output
(
[
decode_req
.
req
],
decode_req
.
req
.
return_logprob
)
# unlock the kv cache or it will have memory leak
self
.
tree_cache
.
cache_finished_req
(
decode_req
.
req
)
indices_to_remove
.
add
(
i
)
continue
elif
poll
==
KVPoll
.
Success
:
elif
poll
==
KVPoll
.
Success
:
# pop and push it to waiting queue
# pop and push it to waiting queue
idx
=
decode_req
.
metadata_buffer_index
idx
=
decode_req
.
metadata_buffer_index
...
...
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
3bde1010
...
@@ -496,6 +496,7 @@ class MooncakeKVSender(BaseKVSender):
...
@@ -496,6 +496,7 @@ class MooncakeKVSender(BaseKVSender):
return
self
.
kv_mgr
.
check_status
(
self
.
bootstrap_room
)
return
self
.
kv_mgr
.
check_status
(
self
.
bootstrap_room
)
def
failure_exception
(
self
):
def
failure_exception
(
self
):
# TODO: raise a real exception
raise
Exception
(
"Fake KVSender Exception"
)
raise
Exception
(
"Fake KVSender Exception"
)
...
@@ -723,6 +724,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
...
@@ -723,6 +724,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
return
self
.
kv_mgr
.
check_status
(
self
.
bootstrap_room
)
return
self
.
kv_mgr
.
check_status
(
self
.
bootstrap_room
)
def
failure_exception
(
self
):
def
failure_exception
(
self
):
# TODO: raise a real exception
raise
Exception
(
"Fake KVReceiver Exception"
)
raise
Exception
(
"Fake KVReceiver Exception"
)
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
3bde1010
...
@@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
...
@@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
kv_to_page_indices
,
kv_to_page_indices
,
kv_to_page_num
,
kv_to_page_num
,
poll_and_all_reduce
,
poll_and_all_reduce
,
prepare_abort
,
)
)
from
sglang.srt.managers.schedule_batch
import
FINISH_LENGTH
,
Req
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
FINISH_LENGTH
,
Req
,
ScheduleBatch
...
@@ -157,7 +158,18 @@ class PrefillBootstrapQueue:
...
@@ -157,7 +158,18 @@ class PrefillBootstrapQueue:
if
poll
==
KVPoll
.
Bootstrapping
:
if
poll
==
KVPoll
.
Bootstrapping
:
continue
continue
elif
poll
==
KVPoll
.
Failed
:
elif
poll
==
KVPoll
.
Failed
:
raise
Exception
(
"Bootstrap failed"
)
error_message
=
f
"Prefill bootstrap failed for request rank=
{
self
.
tp_rank
}
{
req
.
rid
=
}
{
req
.
bootstrap_room
=
}
"
try
:
req
.
disagg_kv_sender
.
failure_exception
()
except
Exception
as
e
:
error_message
+=
f
" with exception
{
e
}
"
logger
.
error
(
error_message
)
prepare_abort
(
req
,
error_message
,
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
self
.
scheduler
.
stream_output
([
req
],
req
.
return_logprob
)
indices_to_remove
.
add
(
i
)
continue
# KV.WaitingForInput
# KV.WaitingForInput
num_kv_indices
=
len
(
req
.
origin_input_ids
)
num_kv_indices
=
len
(
req
.
origin_input_ids
)
...
@@ -335,7 +347,17 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -335,7 +347,17 @@ class SchedulerDisaggregationPrefillMixin:
# FIXME: clean up req's data in transfer engine
# FIXME: clean up req's data in transfer engine
done_reqs
.
append
(
req
)
done_reqs
.
append
(
req
)
elif
poll
==
KVPoll
.
Failed
:
elif
poll
==
KVPoll
.
Failed
:
raise
Exception
(
"Transferring failed"
)
error_message
=
f
"Prefill transfer failed for request rank=
{
self
.
tp_rank
}
{
req
.
rid
=
}
{
req
.
bootstrap_room
=
}
"
try
:
req
.
disagg_kv_sender
.
failure_exception
()
except
Exception
as
e
:
error_message
+=
f
" with exception
{
e
}
"
logger
.
warning
(
error_message
)
self
.
tree_cache
.
cache_finished_req
(
req
)
# unlock the tree
prepare_abort
(
req
,
error_message
,
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
done_reqs
.
append
(
req
)
for
req
in
done_reqs
:
for
req
in
done_reqs
:
self
.
disagg_prefill_bootstrap_queue
.
req_to_metadata_buffer_idx_allocator
.
free
(
self
.
disagg_prefill_bootstrap_queue
.
req_to_metadata_buffer_idx_allocator
.
free
(
...
...
python/sglang/srt/disaggregation/utils.py
View file @
3bde1010
...
@@ -167,3 +167,18 @@ def is_mla_backend(target_kv_pool) -> bool:
...
@@ -167,3 +167,18 @@ def is_mla_backend(target_kv_pool) -> bool:
from
sglang.srt.mem_cache.memory_pool
import
MLATokenToKVPool
from
sglang.srt.mem_cache.memory_pool
import
MLATokenToKVPool
return
isinstance
(
target_kv_pool
,
MLATokenToKVPool
)
return
isinstance
(
target_kv_pool
,
MLATokenToKVPool
)
def prepare_abort(req: Req, error_message: str, status_code=None):
    """Stamp ``req`` as aborted so it can be reported back to the client.

    The request's ``finished_reason`` is set to
    ``FINISH_ABORT(error_message, status_code)``. If the request asked for
    logprobs (``req.return_logprob`` is truthy), every input-logprob field is
    reset to an empty list.

    This function only mutates ``req``; it does not emit anything itself.

    Args:
        req: The request being aborted; modified in place.
        error_message: Human-readable reason, forwarded to ``FINISH_ABORT``.
        status_code: Optional status forwarded to ``FINISH_ABORT``
            (defaults to ``None``).
    """
    # Resolved at call time rather than at module import time
    # (presumably to sidestep a circular import -- TODO confirm).
    from sglang.srt.managers.schedule_batch import FINISH_ABORT

    req.finished_reason = FINISH_ABORT(error_message, status_code)

    # Nothing else to fill in unless the caller requested logprobs.
    if not req.return_logprob:
        return

    # Each field gets its own fresh list; they must not share one object.
    logprob_fields = (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    for field_name in logprob_fields:
        setattr(req, field_name, [])
python/sglang/srt/managers/scheduler.py
View file @
3bde1010
...
@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import (
...
@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import (
DisaggregationMode
,
DisaggregationMode
,
ReqToMetadataIdxAllocator
,
ReqToMetadataIdxAllocator
,
TransferBackend
,
TransferBackend
,
prepare_abort
,
)
)
from
sglang.srt.distributed
import
get_pp_group
,
get_world_group
from
sglang.srt.distributed
import
get_pp_group
,
get_world_group
from
sglang.srt.hf_transformers_utils
import
(
from
sglang.srt.hf_transformers_utils
import
(
...
@@ -935,6 +936,18 @@ class Scheduler(
...
@@ -935,6 +936,18 @@ class Scheduler(
)
)
req
.
tokenizer
=
self
.
tokenizer
req
.
tokenizer
=
self
.
tokenizer
if
self
.
disaggregation_mode
!=
DisaggregationMode
.
NULL
:
# Invalid request for disaggregated mode
if
recv_req
.
bootstrap_room
is
None
:
error_message
=
(
f
"Invalid request: Disaggregated request received without "
f
"boostrap room id.
{
req
.
rid
=
}
"
)
logger
.
error
(
error_message
)
prepare_abort
(
req
,
error_message
)
self
.
stream_output
([
req
],
req
.
return_logprob
)
return
if
(
if
(
recv_req
.
session_params
is
not
None
recv_req
.
session_params
is
not
None
and
recv_req
.
session_params
.
id
is
not
None
and
recv_req
.
session_params
.
id
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment