change / sglang / Commits / 0a4fc73b

Commit 0a4fc73b (unverified)
[PD] Fix failure abort (#6535)

Authored May 22, 2025 by Byron Hsu; committed by GitHub on May 22, 2025.
Parent: a6970a17

Showing 6 changed files with 141 additions and 92 deletions (+141 -92)
python/sglang/srt/disaggregation/decode.py                        +13  -89
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py   +105 -0
python/sglang/srt/disaggregation/utils.py                         +15  -1
python/sglang/srt/managers/schedule_batch.py                      +3   -1
python/sglang/srt/managers/scheduler.py                           +2   -0
python/sglang/srt/mem_cache/chunk_cache.py                        +3   -1
python/sglang/srt/disaggregation/decode.py

@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
```python
    poll_and_all_reduce,
    prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
```
@@ -321,11 +322,15 @@ class DecodeTransferQueue:
```python
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: torch.Tensor,
        scheduler: Scheduler,
        tree_cache: BasePrefixCache,
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.metadata_buffers = metadata_buffers
        self.scheduler = scheduler
        self.tree_cache = tree_cache

    def add(self, req_conn: DecodeRequest) -> None:
        self.queue.append(req_conn)
```
@@ -341,6 +346,14 @@ class DecodeTransferQueue:
```python
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()

        # First, remove all failed requests from the queue
        for i, decode_req in enumerate(self.queue):
            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
                self.scheduler.stream_output(
                    [decode_req.req], decode_req.req.return_logprob
                )
                indices_to_remove.add(i)

        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
```
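The hunk above drains already-aborted requests out of the transfer queue before the normal poll loop runs. A minimal, self-contained sketch of that cleanup pattern is below; it is not sglang code, and `SimpleReq` and the `print` stand in for the real `Req`/`FINISH_ABORT` check and `scheduler.stream_output(...)`.

```python
from dataclasses import dataclass


@dataclass
class SimpleReq:
    rid: str
    aborted: bool = False  # stand-in for isinstance(req.finished_reason, FINISH_ABORT)


def flush_aborted(queue: list[SimpleReq]) -> list[SimpleReq]:
    indices_to_remove = set()
    for i, req in enumerate(queue):
        if req.aborted:
            # stand-in for scheduler.stream_output([req], req.return_logprob)
            print(f"stream aborted request {req.rid} back to the client")
            indices_to_remove.add(i)
    # rebuild the queue without the aborted entries
    return [req for i, req in enumerate(queue) if i not in indices_to_remove]


if __name__ == "__main__":
    q = [SimpleReq("a"), SimpleReq("b", aborted=True), SimpleReq("c")]
    assert [r.rid for r in flush_aborted(q)] == ["a", "c"]
```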
@@ -396,95 +409,6 @@ class DecodeTransferQueue:
```python
        return transferred_reqs
```

The 89 deleted lines are the entire `ScheduleBatchDisaggregationDecodeMixin` class (`prepare_for_prebuilt_extend` and `process_prebuilt_extend`), which this commit moves into the new file `python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py` shown below. The remaining context in `decode.py` continues with:

```python
class SchedulerDisaggregationDecodeMixin:

    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
```
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py (new file, 0 → 100644)

```python
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch

from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.server_args import ServerArgs


class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend by populate metadata
        Adapted from .prepare_for_extend().
        """

        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)
            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to schedule batch"""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)
```
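`prepare_for_prebuilt_extend` above packs every request's KV slot indices into one flat `out_cache_loc` tensor: it pre-computes the total size, allocates once, then fills each request's slice at a running offset. The sketch below reproduces just that packing pattern with toy data; the names `extend_lens` and `per_req_slots` are illustrative stand-ins, not sglang structures.

```python
import torch

extend_lens = [3, 5, 2]  # hypothetical per-request extend lengths
per_req_slots = [torch.arange(n, dtype=torch.int64) for n in extend_lens]  # stand-ins for req_to_token rows

total_size = sum(extend_lens)
out_cache_loc = torch.empty(total_size, dtype=torch.int64)

offset = 0
for n, slots in zip(extend_lens, per_req_slots):
    # copy this request's slot indices into its contiguous slice
    out_cache_loc[offset : offset + n] = slots
    offset += n

assert offset == total_size
```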
python/sglang/srt/disaggregation/utils.py

```python
from __future__ import annotations

import dataclasses
import os
import random
import warnings
from collections import deque
from enum import Enum
```

@@ -15,6 +17,9 @@ from sglang.srt.utils import get_ip
```python
FakeBootstrapHost = "2.2.2.2"

# env var for testing failure, convert to float explicitly
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))


class DisaggregationMode(Enum):
    NULL = "null"
```
@@ -23,6 +28,15 @@ class DisaggregationMode(Enum):
```python
def poll_and_all_reduce(pollers, gloo_group):
    # at a certain prob, the poll is failed to simulate failure
    if FAILURE_PROB > 0:
        from sglang.srt.disaggregation.base import KVPoll

        polls = [
            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
            for poller in pollers
        ]
    else:
        polls = [int(poller.poll()) for poller in pollers]
    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
```
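Two ideas combine in `poll_and_all_reduce` above: `DISAGGREGATION_TEST_FAILURE_PROB` randomly downgrades a poll result to a failure, and a MIN all-reduce over the CPU/gloo group makes every rank adopt the most pessimistic status, so an injected failure on one rank aborts the request on all ranks. Below is a minimal runnable sketch of that combination, not the sglang test harness; `FAILED = 0` is an assumption standing in for sglang's `KVPoll` values, and the single-process gloo setup is only for demonstration.

```python
import os
import random

import torch
import torch.distributed as dist

FAILED = 0  # assumed smallest status code, so ReduceOp.MIN propagates failures
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))


def poll_with_injected_failure(poll_fns, gloo_group):
    # maybe replace each poll result with FAILED, then agree on the minimum across ranks
    polls = [
        FAILED if FAILURE_PROB > 0 and random.random() < FAILURE_PROB else fn()
        for fn in poll_fns
    ]
    t = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=gloo_group)
    return t.tolist()


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    # two fake pollers that report a hypothetical "success"-like status of 4
    print(poll_with_injected_failure([lambda: 4, lambda: 4], dist.group.WORLD))
    dist.destroy_process_group()
```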
python/sglang/srt/managers/schedule_batch.py

@@ -48,7 +48,9 @@ from sglang.global_config import global_config
```diff
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
-from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
+from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
+    ScheduleBatchDisaggregationDecodeMixin,
+)
 from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
```
python/sglang/srt/managers/scheduler.py

@@ -582,6 +582,8 @@ class Scheduler(
```python
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
```
python/sglang/srt/mem_cache/chunk_cache.py

@@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache):
```diff
     def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
+            req.req_pool_idx,
+            # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
+            : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
```
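The chunk_cache change swaps `len(req.output_ids) - 1` for `max(len(req.output_ids) - 1, 0)` in the slice length. The sketch below is illustrative arithmetic only (toy numbers, not sglang data structures): with an empty `output_ids`, the old expression under-counts by one and would leave a KV slot unfreed, while the new one frees the full prompt.

```python
origin_input_ids = list(range(8))  # 8 prompt tokens
output_ids = []                    # request freed before any token was generated

old_len = len(origin_input_ids) + len(output_ids) - 1          # 7 -> one slot short
new_len = len(origin_input_ids) + max(len(output_ids) - 1, 0)  # 8 -> frees the whole prompt

assert (old_len, new_len) == (7, 8)
```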