Unverified commit 50188092, authored Aug 09, 2025 by Cheng Wan, committed by GitHub on Aug 09, 2025

[DP] fix: engine crash when decode batch is padded (#8995)
parent 326a901d

Showing 2 changed files with 33 additions and 18 deletions (+33, -18):

python/sglang/srt/layers/communicator.py (+7, -13)
python/sglang/srt/model_executor/forward_batch_info.py (+26, -5)
python/sglang/srt/layers/communicator.py
@@ -408,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                forward_batch.gathered_buffer[
-                    : forward_batch.input_ids.shape[0]
-                ].clone(),
+                torch.empty_like(
+                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
+                ),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
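The only functional change in this hunk is that the all-gather destination is now allocated with torch.empty_like instead of cloning a slice of forward_batch.gathered_buffer. The diff itself does not state the rationale, but a plausible reading is that the destination only needs the slice's shape and dtype: every row is overwritten by attn_tp_all_gather_into_tensor, so cloning copies (possibly stale) scratch data for nothing, and under decode-batch padding the scratch buffer's contents cannot be trusted. A torch-only sketch of the difference, with made-up sizes rather than sglang's real shapes:

    import torch

    # Made-up sizes; gathered_buffer stands in for the shared DP scratch buffer.
    padded_tokens, hidden = 8, 4
    local_tokens = 3
    gathered_buffer = torch.full((padded_tokens, hidden), float("nan"))  # stale scratch

    # Before: clone() copies the (stale) scratch contents into the destination.
    dst_old = gathered_buffer[:local_tokens].clone()
    assert torch.isnan(dst_old).all()

    # After: empty_like() only reuses shape/dtype/device; no copy, no aliasing.
    dst_new = torch.empty_like(gathered_buffer[:local_tokens])
    assert dst_new.shape == dst_old.shape
    assert dst_new.data_ptr() != gathered_buffer.data_ptr()

    # Either way, the collective then overwrites every row of the destination,
    # which is why the initial contents never needed to be copied.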
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
         use_layer_norm_before_gather = context.attn_tp_size == 1
-        if use_layer_norm_before_gather:
-            residual.copy_(hidden_states)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
+        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+            residual = hidden_states
+            hidden_states = layernorm(hidden_states)
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer,
+            torch.empty_like(forward_batch.gathered_buffer),
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
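Two things change here. First, the zero-token guard moves into the outer condition, so when a padded rank holds no local tokens the branch does nothing at all (previously residual.copy_(hidden_states) still ran), and when tokens do exist residual simply aliases hidden_states instead of being copied into. Second, the gather destination becomes torch.empty_like(forward_batch.gathered_buffer) rather than the buffer itself, matching the hunk above. The guard pattern is easy to check in isolation; a minimal sketch using torch.nn.LayerNorm as a stand-in for sglang's fused norm:

    import torch

    def norm_nonempty(hidden_states: torch.Tensor, layernorm) -> torch.Tensor:
        # With DP padding a rank can hold zero local tokens; skipping the call
        # avoids handing an empty tensor to a fused kernel.
        if hidden_states.shape[0] != 0:
            return layernorm(hidden_states)
        return hidden_states

    ln = torch.nn.LayerNorm(4)  # stand-in for the fused RMSNorm in sglang
    print(norm_nonempty(torch.randn(3, 4), ln).shape)  # torch.Size([3, 4])
    print(norm_nonempty(torch.empty(0, 4), ln).shape)  # torch.Size([0, 4])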
@@ -552,10 +550,6 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        if hidden_states.data_ptr() is global_hidden_states.data_ptr():
-            hidden_states = torch.empty_like(hidden_states)
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
             # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
             dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
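With the destinations now freshly allocated, the old data_ptr() aliasing workaround becomes dead code and is dropped. The comment kept in the hunk explains the padding path: when dp_padding_mode.is_max_len(), every rank holds the same padded token count, so the all_reduce after MLP and MoE can be replaced by a reduce-scatter that both sums the partial results and leaves each rank only its shard. A single-process sketch that emulates that arithmetic with plain torch (sizes are illustrative; the real code uses a distributed collective via dp_reduce_scatter_tensor):

    import torch

    dp_size, max_len, hidden = 4, 8, 16

    # One partial result per rank, each covering all ranks' padded tokens.
    partials = [torch.randn(dp_size * max_len, hidden) for _ in range(dp_size)]

    # all_reduce would give every rank the full summed tensor ...
    summed = torch.stack(partials).sum(dim=0)

    # ... while reduce-scatter gives each rank only its max_len-sized shard.
    shards = list(summed.chunk(dp_size, dim=0))
    for rank in range(dp_size):
        assert shards[rank].shape == (max_len, hidden)
        assert torch.equal(shards[rank], summed[rank * max_len : (rank + 1) * max_len])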
python/sglang/srt/model_executor/forward_batch_info.py
@@ -653,12 +653,30 @@ class ForwardBatch:
         else:
             num_tokens = global_num_tokens[0]
 
-        if self.forward_mode.is_decode():
-            setattr(self, "raw_bs", self.batch_size)
-            self.batch_size = num_tokens
-
         bs = self.batch_size
 
+        if self.forward_mode.is_decode():
+            if self.is_extend_in_batch and dp_padding_mode.is_max_len():
+                setattr(self, "_original_forward_mode", self.forward_mode)
+                self.forward_mode = ForwardMode.EXTEND
+                self.extend_num_tokens = bs
+                self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
+                self.extend_prefix_lens = self.seq_lens - 1
+                self.extend_start_loc = torch.arange(
+                    bs, dtype=torch.int32, device=self.seq_lens.device
+                )
+                self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
+                self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
+                self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
+            else:
+                setattr(self, "_original_batch_size", self.batch_size)
+                if self.spec_info is not None:
+                    bs = self.batch_size = (
+                        num_tokens // self.spec_info.num_tokens_per_batch
+                    )
+                else:
+                    bs = self.batch_size = num_tokens
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
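This is the heart of the fix: when a decode batch must be padded to the DP group's max length while sibling ranks run extend batches (self.is_extend_in_batch together with max-len padding), the batch is re-described as a one-token-per-request EXTEND batch so that all ranks agree on the forward mode, and the original mode is stashed in _original_forward_mode for restoration after the forward pass. The derived metadata is straightforward to verify in isolation; a sketch with made-up seq_lens:

    import torch

    # Decode generates exactly one new token per request, so as an EXTEND batch
    # each request has extend length 1 and everything else counted as prefix.
    seq_lens = torch.tensor([5, 9, 3], dtype=torch.int32)  # made-up example
    bs = seq_lens.numel()

    extend_seq_lens = torch.full_like(seq_lens, 1)           # one new token each
    extend_prefix_lens = seq_lens - 1                        # rest comes from cache
    extend_start_loc = torch.arange(bs, dtype=torch.int32)   # token i -> request i

    # Invariant: prefix + extend reconstructs the full sequence length.
    assert torch.equal(extend_prefix_lens + extend_seq_lens, seq_lens)
    print(extend_start_loc.tolist())  # [0, 1, 2]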
@@ -689,6 +707,7 @@ class ForwardBatch:
         if self.mrope_positions is not None:
             self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
 
         # TODO: check if we need to pad other tensors
+        if self.extend_seq_lens is not None:
+            self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
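extend_seq_lens now gets padded alongside the other per-request tensors, since the decode-to-EXTEND conversion above leaves it populated on padded decode batches. _pad_tensor_to_size itself is not shown in this diff; a plausible minimal implementation, assuming it zero-pads the leading dimension up to the target length:

    import torch

    def pad_tensor_to_size(t: torch.Tensor, size: int) -> torch.Tensor:
        # Hypothetical stand-in for ForwardBatch._pad_tensor_to_size: grow the
        # leading dimension to `size`, leaving longer tensors untouched.
        if t.shape[0] >= size:
            return t
        pad_shape = (size - t.shape[0],) + tuple(t.shape[1:])
        return torch.cat([t, t.new_zeros(pad_shape)], dim=0)

    x = torch.tensor([3, 1, 4])
    print(pad_tensor_to_size(x, 5).tolist())  # [3, 1, 4, 0, 0]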
@@ -712,7 +731,9 @@ class ForwardBatch:
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
 
-        bs = getattr(self, "raw_bs", self.batch_size)
+        self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
+        self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
+        bs = self.batch_size
 
         if self.spec_info is not None:
             if self.forward_mode.is_decode():
                 # draft
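The restore side relies on getattr with a default: ranks that never rewrote their batch have no _original_forward_mode or _original_batch_size attribute, so the getattr calls are no-ops there and the restore can run unconditionally on every rank. A tiny sketch of the stash/restore pattern:

    class Batch:
        def __init__(self):
            self.forward_mode = "DECODE"

    b = Batch()

    # Stash and mutate (only on ranks whose decode batch was rewritten):
    setattr(b, "_original_forward_mode", b.forward_mode)
    b.forward_mode = "EXTEND"

    # Restore: falls back to the current value when nothing was stashed,
    # so it is safe to call unconditionally.
    b.forward_mode = getattr(b, "_original_forward_mode", b.forward_mode)
    assert b.forward_mode == "DECODE"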