Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
fa6723f0
Unverified
Commit
fa6723f0
authored
May 27, 2025
by
Yineng Zhang
Committed by
GitHub
May 27, 2025
Browse files
Revert "fix communicator for non-dp lm head (#6662)" (#6677)
parent
673ff668
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
13 deletions
+5
-13
python/sglang/srt/layers/communicator.py
python/sglang/srt/layers/communicator.py
+4
-12
python/sglang/srt/models/qwen2_moe.py
python/sglang/srt/models/qwen2_moe.py
+1
-1
No files found.
python/sglang/srt/layers/communicator.py
View file @
fa6723f0
...
@@ -50,18 +50,10 @@ class ScatterMode(Enum):
...
@@ -50,18 +50,10 @@ class ScatterMode(Enum):
FULL
=
auto
()
FULL
=
auto
()
@
staticmethod
@
staticmethod
def
model_input_
mode
():
def
model_input_
output
():
"""The scatter mode for model
in
put data"""
"""The scatter mode for model
forward pass input and out
put data"""
return
ScatterMode
.
TP_ATTN_FULL
return
ScatterMode
.
TP_ATTN_FULL
@
staticmethod
def
model_output_mode
():
"""The scatter mode for model output data"""
if
global_server_args_dict
[
"enable_dp_lm_head"
]:
return
ScatterMode
.
TP_ATTN_FULL
else
:
return
ScatterMode
.
FULL
@
dataclass
@
dataclass
class
_LayerModeComputationContext
:
class
_LayerModeComputationContext
:
...
@@ -103,7 +95,7 @@ class LayerScatterModes:
...
@@ -103,7 +95,7 @@ class LayerScatterModes:
@
classmethod
@
classmethod
def
_compute_layer_input_mode
(
cls
,
context
:
_LayerModeComputationContext
):
def
_compute_layer_input_mode
(
cls
,
context
:
_LayerModeComputationContext
):
if
context
.
layer_id
==
0
:
if
context
.
layer_id
==
0
:
return
ScatterMode
.
model_input_
mode
()
return
ScatterMode
.
model_input_
output
()
return
cls
.
_compute_layer_output_mode
(
context
.
previous_layer
())
return
cls
.
_compute_layer_output_mode
(
context
.
previous_layer
())
@
classmethod
@
classmethod
...
@@ -134,7 +126,7 @@ class LayerScatterModes:
...
@@ -134,7 +126,7 @@ class LayerScatterModes:
def
_compute_layer_output_mode
(
cls
,
context
:
_LayerModeComputationContext
):
def
_compute_layer_output_mode
(
cls
,
context
:
_LayerModeComputationContext
):
mlp_mode
=
cls
.
_compute_mlp_mode
(
context
)
mlp_mode
=
cls
.
_compute_mlp_mode
(
context
)
if
context
.
layer_id
==
context
.
num_layers
-
1
:
if
context
.
layer_id
==
context
.
num_layers
-
1
:
return
ScatterMode
.
model_output
_mode
()
return
ScatterMode
.
model_
input_
output
()
if
mlp_mode
==
ScatterMode
.
SCATTERED
:
if
mlp_mode
==
ScatterMode
.
SCATTERED
:
return
ScatterMode
.
SCATTERED
return
ScatterMode
.
SCATTERED
if
mlp_mode
==
ScatterMode
.
FULL
:
if
mlp_mode
==
ScatterMode
.
FULL
:
...
...
python/sglang/srt/models/qwen2_moe.py
View file @
fa6723f0
...
@@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module):
...
@@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module):
hidden_states
,
residual
=
model_forward_maybe_tbo
(
hidden_states
,
residual
=
model_forward_maybe_tbo
(
layers
=
self
.
layers
,
layers
=
self
.
layers
,
enable_tbo
=
True
,
enable_tbo
=
True
,
input_data_scatter_mode
=
ScatterMode
.
model_input_
mode
(),
input_data_scatter_mode
=
ScatterMode
.
model_input_
output
(),
positions
=
positions
,
positions
=
positions
,
forward_batch
=
forward_batch
,
forward_batch
=
forward_batch
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment