sglang · commit 0a4fc73b

[PD] Fix failure abort (#6535)

Unverified commit 0a4fc73b, authored May 22, 2025 by Byron Hsu; committed via GitHub on May 22, 2025.
Parent: a6970a17

Showing 6 changed files with 141 additions and 92 deletions (+141 -92).
python/sglang/srt/disaggregation/decode.py                        +13  -89
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py   +105 -0   (new file)
python/sglang/srt/disaggregation/utils.py                         +15  -1
python/sglang/srt/managers/schedule_batch.py                      +3   -1
python/sglang/srt/managers/scheduler.py                           +2   -0
python/sglang/srt/mem_cache/chunk_cache.py                        +3   -1
python/sglang/srt/disaggregation/decode.py

@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
+from sglang.srt.managers.schedule_batch import FINISH_ABORT
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardMode

@@ -321,11 +322,15 @@ class DecodeTransferQueue:
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
         metadata_buffers: torch.Tensor,
+        scheduler: Scheduler,
+        tree_cache: BasePrefixCache,
     ):
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.metadata_buffers = metadata_buffers
+        self.scheduler = scheduler
+        self.tree_cache = tree_cache

     def add(self, req_conn: DecodeRequest) -> None:
         self.queue.append(req_conn)

@@ -341,6 +346,14 @@ class DecodeTransferQueue:
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )

+        # First, remove all failed requests from the queue
+        for i, decode_req in enumerate(self.queue):
+            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                indices_to_remove.add(i)
+
         transferred_reqs = []
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
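The new block above streams aborted requests straight back to the client and marks their queue slots for removal before transfer polling continues. The standalone sketch below illustrates that mark-then-rebuild pattern only; FakeReq and the print-based stream_output are made-up stand-ins, not the sglang implementation.

# Minimal, self-contained sketch of the abort-and-remove pattern (not sglang code).
from dataclasses import dataclass

@dataclass
class FakeReq:
    rid: str
    aborted: bool = False  # stands in for isinstance(finished_reason, FINISH_ABORT)

def stream_output(reqs):
    # In sglang this would return the aborted result to the client immediately.
    for r in reqs:
        print(f"streamed abort for {r.rid}")

queue = [FakeReq("a"), FakeReq("b", aborted=True), FakeReq("c")]

indices_to_remove = set()
for i, req in enumerate(queue):
    if req.aborted:
        stream_output([req])
        indices_to_remove.add(i)

# Rebuild the queue without the aborted entries.
queue = [req for i, req in enumerate(queue) if i not in indices_to_remove]
assert [r.rid for r in queue] == ["a", "c"]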
@@ -396,95 +409,6 @@ class DecodeTransferQueue:
         return transferred_reqs

 [89 lines removed here: the entire ScheduleBatchDisaggregationDecodeMixin class,
  i.e. prepare_for_prebuilt_extend() and process_prebuilt_extend(). The class moves
  unchanged into the new file python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py,
  reproduced in full below.]

 class SchedulerDisaggregationDecodeMixin:

     def _prepare_idle_batch_and_run(self, batch, delay_process=False):
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py  (new file, 0 → 100644)

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch

from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.server_args import ServerArgs


class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend by populate metadata
        Adapted from .prepare_for_extend().
        """

        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)
            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to schedule batch"""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)
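The `self: ScheduleBatch` annotations rely on the usual mixin arrangement: the methods are defined in this module but are meant to run on a ScheduleBatch instance, which presumably lists this mixin among its base classes (the import change in schedule_batch.py below is consistent with that). A toy, self-contained illustration of the pattern, using hypothetical Demo* names rather than sglang classes:

# Toy illustration of the mixin pattern above (not sglang code).
from __future__ import annotations

class DemoDecodeMixin:
    # Annotating self with the concrete class documents that the method
    # expects DemoBatch attributes (here: .reqs) to exist.
    def describe(self: DemoBatch) -> str:
        return f"batch of {len(self.reqs)} reqs"

class DemoBatch(DemoDecodeMixin):
    def __init__(self, reqs: list[int]):
        self.reqs = reqs

print(DemoBatch([1, 2, 3]).describe())  # -> batch of 3 reqs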
python/sglang/srt/disaggregation/utils.py

 from __future__ import annotations

 import dataclasses
+import os
+import random
 import warnings
 from collections import deque
 from enum import Enum

@@ -15,6 +17,9 @@ from sglang.srt.utils import get_ip
 FakeBootstrapHost = "2.2.2.2"

+# env var for testing failure, convert to float explicitly
+FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+

 class DisaggregationMode(Enum):
     NULL = "null"

@@ -23,6 +28,15 @@ class DisaggregationMode(Enum):
 def poll_and_all_reduce(pollers, gloo_group):
-    polls = [int(poller.poll()) for poller in pollers]
+    # at a certain prob, the poll is failed to simulate failure
+    if FAILURE_PROB > 0:
+        from sglang.srt.disaggregation.base import KVPoll
+
+        polls = [
+            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
+            for poller in pollers
+        ]
+    else:
+        polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
     dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
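The new FAILURE_PROB hook makes KV-transfer failures reproducible in tests: when DISAGGREGATION_TEST_FAILURE_PROB is set to a value in (0, 1], roughly that fraction of polls is reported as KVPoll.Failed, which exercises the abort path this commit fixes. Below is a rough standalone sketch of the same injection logic; FakePoller, the numeric stand-ins for the KVPoll members, and the 0.3 probability are made up for illustration and are not part of sglang.

# Standalone sketch of the failure-injection logic added above (stand-ins only).
import os
import random

os.environ.setdefault("DISAGGREGATION_TEST_FAILURE_PROB", "0.3")  # hypothetical value
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))

KV_POLL_SUCCESS, KV_POLL_FAILED = 1, 0  # stand-ins for KVPoll members

class FakePoller:
    def poll(self) -> int:
        return KV_POLL_SUCCESS

pollers = [FakePoller() for _ in range(8)]
if FAILURE_PROB > 0:
    polls = [
        KV_POLL_FAILED if random.random() < FAILURE_PROB else poller.poll()
        for poller in pollers
    ]
else:
    polls = [poller.poll() for poller in pollers]

print(polls)  # roughly 30% of entries injected as failed

In practice the variable would presumably be exported in the environment of the prefill/decode server processes before launching them, e.g. DISAGGREGATION_TEST_FAILURE_PROB=0.05.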
python/sglang/srt/managers/schedule_batch.py

@@ -48,7 +48,9 @@ from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
-from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
+from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
+    ScheduleBatchDisaggregationDecodeMixin,
+)
 from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
python/sglang/srt/managers/scheduler.py

@@ -582,6 +582,8 @@ class Scheduler(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
+                scheduler=self,
+                tree_cache=self.tree_cache,
             )

             # The decode requests pending for pre-allocation
python/sglang/srt/mem_cache/chunk_cache.py

@@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache):
     def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
+            req.req_pool_idx,
+            # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
+            : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
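The switch to max(len(req.output_ids) - 1, 0) covers a decode-server request that finishes (e.g. is aborted) before producing any output token: with the old expression the slice end came up one short of the prompt length, so the KV slot of the last prompt token was never freed. A quick self-contained check with hypothetical sizes (plain Python arithmetic, no sglang imports):

# Worked example for the slice-length change above.
origin_input_ids = list(range(5))  # 5 prompt tokens
output_ids_empty = []              # finished/aborted before any output token
output_ids_three = [11, 12, 13]    # request with 3 output tokens

def old_end(output_ids):
    return len(origin_input_ids) + len(output_ids) - 1

def new_end(output_ids):
    return len(origin_input_ids) + max(len(output_ids) - 1, 0)

assert old_end(output_ids_empty) == 4  # old slice stops one short: one KV slot leaks
assert new_end(output_ids_empty) == 5  # new slice frees all 5 prompt slots
assert old_end(output_ids_three) == new_end(output_ids_three) == 7  # unchanged otherwise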