Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a3d7f4b6
Unverified
Commit
a3d7f4b6
authored
May 27, 2025
by
Cheng Wan
Committed by
GitHub
May 27, 2025
Browse files
fix communicator for non-dp lm head (#6662)
parent
b18416fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
5 deletions
+13
-5
python/sglang/srt/layers/communicator.py
python/sglang/srt/layers/communicator.py
+12
-4
python/sglang/srt/models/qwen2_moe.py
python/sglang/srt/models/qwen2_moe.py
+1
-1
No files found.
python/sglang/srt/layers/communicator.py
View file @
a3d7f4b6
...
...
@@ -50,10 +50,18 @@ class ScatterMode(Enum):
FULL
=
auto
()
@
staticmethod
def
model_input_
output
():
"""The scatter mode for model
forward pass input and out
put data"""
def
model_input_
mode
():
"""The scatter mode for model
in
put data"""
return
ScatterMode
.
TP_ATTN_FULL
@
staticmethod
def
model_output_mode
():
"""The scatter mode for model output data"""
if
global_server_args_dict
[
"enable_dp_lm_head"
]:
return
ScatterMode
.
TP_ATTN_FULL
else
:
return
ScatterMode
.
FULL
@
dataclass
class
_LayerModeComputationContext
:
...
...
@@ -95,7 +103,7 @@ class LayerScatterModes:
@
classmethod
def
_compute_layer_input_mode
(
cls
,
context
:
_LayerModeComputationContext
):
if
context
.
layer_id
==
0
:
return
ScatterMode
.
model_input_
output
()
return
ScatterMode
.
model_input_
mode
()
return
cls
.
_compute_layer_output_mode
(
context
.
previous_layer
())
@
classmethod
...
...
@@ -126,7 +134,7 @@ class LayerScatterModes:
def
_compute_layer_output_mode
(
cls
,
context
:
_LayerModeComputationContext
):
mlp_mode
=
cls
.
_compute_mlp_mode
(
context
)
if
context
.
layer_id
==
context
.
num_layers
-
1
:
return
ScatterMode
.
model_
input_
output
()
return
ScatterMode
.
model_output
_mode
()
if
mlp_mode
==
ScatterMode
.
SCATTERED
:
return
ScatterMode
.
SCATTERED
if
mlp_mode
==
ScatterMode
.
FULL
:
...
...
python/sglang/srt/models/qwen2_moe.py
View file @
a3d7f4b6
...
...
@@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module):
hidden_states
,
residual
=
model_forward_maybe_tbo
(
layers
=
self
.
layers
,
enable_tbo
=
True
,
input_data_scatter_mode
=
ScatterMode
.
model_input_
output
(),
input_data_scatter_mode
=
ScatterMode
.
model_input_
mode
(),
positions
=
positions
,
forward_batch
=
forward_batch
,
hidden_states
=
hidden_states
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment