change / sglang · Commits · 3b2680a4

Commit 3b2680a4 (Unverified), authored May 08, 2025 by fzyzcjy, committed by GitHub on May 08, 2025
Overlap shared expert and routed expert computations (#5121)
Parent: 79961afa
Showing 2 changed files with 54 additions and 8 deletions (+54, -8)
python/sglang/srt/models/llama.py    +1  -1
python/sglang/srt/models/llama4.py   +53 -7
python/sglang/srt/models/llama.py (view file @ 3b2680a4)
@@ -90,7 +90,7 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x):
+    def forward(self, x, forward_batch=None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
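The only change here is an unused `forward_batch=None` keyword. A hedged reading of why it matters: the Llama 4 decoder layer (see the last llama4.py hunk below) now passes `forward_batch` to whatever module sits in `feed_forward`, so the dense `LlamaMLP` must accept the argument even though it ignores it. A minimal sketch of that duck-typed call site, with made-up class names rather than the commit's code:

```python
# Toy illustration only: DenseFFN / MoEFFN are hypothetical stand-ins for
# LlamaMLP and Llama4MoE; the point is the shared call signature.
class DenseFFN:
    def forward(self, x, forward_batch=None):
        # Dense path: forward_batch is accepted but never used.
        return [v * 2 for v in x]

class MoEFFN:
    def forward(self, x, forward_batch):
        # MoE path: behavior depends on the batch's forward mode.
        scale = 3 if forward_batch == "decode" else 4
        return [v * scale for v in x]

def decoder_layer(ffn, hidden, forward_batch):
    # One call site works for both module types, mirroring
    # `self.feed_forward(hidden_states, forward_batch)` in the diff.
    return ffn.forward(hidden, forward_batch)

print(decoder_layer(DenseFFN(), [1.0], "decode"))  # [2.0]
print(decoder_layer(MoEFFN(), [1.0], "decode"))    # [3.0]
```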
python/sglang/srt/models/llama4.py (view file @ 3b2680a4)
@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
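`torch.get_device_module()` returns the accelerator backend module (for example `torch.cuda` on NVIDIA GPUs), which is what lets the stream handling added further down avoid hard-coding CUDA. A small hedged sketch of that usage, assuming a recent PyTorch build where this helper is available and an accelerator is present:

```python
import torch

# device_module plays the role of torch.cuda (or the equivalent backend
# module on other accelerators), so the calls below stay device-agnostic.
device_module = torch.get_device_module()

side_stream = device_module.Stream()
side_stream.wait_stream(device_module.current_stream())
with device_module.stream(side_stream):
    pass  # work queued here runs on the side stream
device_module.current_stream().wait_stream(side_stream)
```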
@@ -113,7 +118,25 @@
             reduce_results=False,  # We need to do scatter before reduce
         )

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
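A note on the dispatch above: the overlapped path is gated on `hidden_states.shape[0] < 4`, so only very small token batches (presumably decode-sized steps) take the extra stream; a handful of tokens leaves each kernel underutilizing the GPU, while larger batches already keep it busy and go through the unchanged sequential path. A trivial, hypothetical restatement of that rule in isolation:

```python
# Hypothetical restatement of the threshold from _forward_core: small batches
# overlap shared and routed experts on two streams, large batches do not.
OVERLAP_TOKEN_THRESHOLD = 4

def moe_path_for(num_tokens: int) -> str:
    return "overlap" if num_tokens < OVERLAP_TOKEN_THRESHOLD else "normal"

assert moe_path_for(1) == "overlap"   # single-token decode step
assert moe_path_for(512) == "normal"  # prefill-sized batch
```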
@@ -121,12 +144,35 @@ class Llama4MoE(nn.Module):
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
-
-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
-
-        return out_aD
+        return shared_out, routed_out
+
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)
+
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream


 class Llama4Attention(nn.Module):
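For readers piecing the hunk together from the diff, here is a self-contained sketch of the same two-stream overlap pattern written directly against `torch.cuda`. `shared_expert`, `router`, and `routed_experts` are placeholder callables, not the commit's modules, and the real code also layers in the device-agnostic `device_module` indirection shown above plus the tensor-parallel all-reduce back in `forward`.

```python
import torch

_alt_stream = None  # created lazily and reused, as in the diff


def overlapped_shared_routed(hidden_states, shared_expert, router, routed_experts):
    """Run the shared expert and the routed experts concurrently on two streams."""
    global _alt_stream
    if _alt_stream is None:
        _alt_stream = torch.cuda.Stream()

    # The side stream must not start before the work that produced
    # hidden_states on the default stream has been enqueued.
    _alt_stream.wait_stream(torch.cuda.current_stream())

    # Shared expert runs on the default stream...
    shared_out = shared_expert(hidden_states)

    # ...while the router + routed experts run on the side stream, overlapping.
    with torch.cuda.stream(_alt_stream):
        router_logits = router(hidden_states)
        routed_out = routed_experts(hidden_states, router_logits)

    # Re-join: the default stream waits for the side stream before anything
    # reads routed_out (e.g. the shared_out + routed_out add in forward).
    torch.cuda.current_stream().wait_stream(_alt_stream)
    return shared_out, routed_out
```

Summing the two outputs and, under tensor parallelism, all-reducing the result then happens back in `forward`, exactly as the reconstructed hunks above show.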
@@ -380,7 +426,7 @@ class Llama4DecoderLayer(nn.Module):
         )

         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)

         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter