Commit 52c03f16 (unverified), authored Jan 27, 2025 by Lianmin Zheng, committed by GitHub on Jan 27, 2025
Parent commit: 741fccd7

Add activation parameters to fused_moe (#3170)

Showing 7 changed files with 52 additions and 7 deletions (+52 / -7):
  python/sglang/srt/layers/moe/ep_moe/layer.py                 +3   -0
  python/sglang/srt/layers/moe/fused_moe_native.py             +17  -3
  python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py   +18  -1
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py       +9   -0
  python/sglang/srt/layers/quantization/fp8.py                 +4   -1
  python/sglang/srt/models/grok.py                             +1   -0
  test/srt/test_fp8_kernel.py                                  +0   -2
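Taken together, the diffs below thread a single new keyword argument, `activation: str = "silu"`, from the public MoE entry points down to the kernels, with `"gelu"` as the first alternative. The sketch below shows a call site for the Triton entry point; only the `topk`, `renormalize`, and `activation` parameter names are confirmed by this diff, while the leading positional arguments and tensor shapes are assumed from the usual fused_moe convention, and running it requires a CUDA device with sglang and triton installed.

```python
# Illustrative sketch only: leading argument order and shapes are assumed from the
# common fused_moe convention; the `activation` keyword is what this commit adds.
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

M, K, N, E, topk = 4, 128, 256, 8, 2   # tokens, hidden, intermediate, experts, top-k
dev, dt = "cuda", torch.float16

x = torch.randn(M, K, device=dev, dtype=dt)              # token activations
w1 = torch.randn(E, 2 * N, K, device=dev, dtype=dt)      # fused gate+up projection per expert
w2 = torch.randn(E, K, N, device=dev, dtype=dt)          # down projection per expert
router_logits = torch.randn(M, E, device=dev, dtype=dt)  # gating scores

out = fused_moe(
    x, w1, w2, router_logits,
    topk=topk,
    renormalize=True,
    activation="gelu",   # new in this commit; defaults to "silu"
)
```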
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ):
         super().__init__()

@@ -140,6 +141,7 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.activation = activation
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()

@@ -166,6 +168,7 @@ class EPMoE(torch.nn.Module):
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
python/sglang/srt/layers/moe/fused_moe_native.py

@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts

@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,

@@ -41,7 +42,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(

@@ -84,6 +91,13 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):

@@ -96,7 +110,7 @@ def moe_forward_native(
         layer_w2_weight = layer.w2_weight[i]
         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx
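The native (pure PyTorch) path now dispatches on the `activation` string in two places: elementwise `F.silu`/`F.gelu` on the gate projection in `fused_moe_forward_native`, and `SiluAndMul`/`GeluAndMul` in `moe_forward_native`. The standalone sketch below mirrors that gated-activation pattern for a single expert using only torch; the function name and weight layout are illustrative, not sglang APIs.

```python
# Standalone sketch of the gated activation selected above; runs on CPU with plain torch.
import torch
import torch.nn.functional as F


def gated_mlp_expert(x, w_gate, w_up, w_down, activation: str = "silu"):
    """One expert FFN: act(x @ w_gate.T) * (x @ w_up.T), followed by the down projection."""
    gate = F.linear(x, w_gate)
    up = F.linear(x, w_up)
    if activation == "silu":
        gate = F.silu(gate)
    elif activation == "gelu":
        gate = F.gelu(gate)
    else:
        raise ValueError(f"Unsupported activation: {activation=}")
    return F.linear(gate * up, w_down)


x = torch.randn(4, 16)         # 4 tokens, hidden size 16
w_gate = torch.randn(32, 16)   # intermediate size 32
w_up = torch.randn(32, 16)
w_down = torch.randn(16, 32)
print(gated_mlp_expert(x, w_gate, w_up, w_down, activation="gelu").shape)  # torch.Size([4, 16])
```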
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -711,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -726,6 +727,7 @@ def inplace_fused_experts(
         topk_weights,
         topk_ids,
         True,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,

@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -767,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -782,6 +786,7 @@ def outplace_fused_experts(
         topk_weights,
         topk_ids,
         False,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,

@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -824,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -839,6 +846,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,

@@ -855,6 +863,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,

@@ -872,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -986,7 +996,12 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")

         invoke_fused_moe_kernel(
             intermediate_cache2,

@@ -1042,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,

@@ -1111,6 +1127,7 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
+        activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
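In the Triton path the activation is applied by a fused "activation and multiply" op: `intermediate_cache1` holds the concatenated gate and up projections of width 2N, and the selected op writes `act(gate) * up` into `intermediate_cache2`. The CPU reference below mirrors that semantics with plain torch; it is not the sgl-kernel op itself, and the exact GELU variant (exact vs. tanh approximation) used by the kernel is an implementation detail not shown in this diff.

```python
# CPU reference for the fused "activation and multiply" step dispatched above.
# ops.silu_and_mul / ops.gelu_and_mul are custom kernels; this only mirrors their semantics.
import torch
import torch.nn.functional as F


def act_and_mul_reference(x2n: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    """x2n has shape (tokens, 2*N): first half is the gate, second half the up projection."""
    n = x2n.shape[-1] // 2
    gate, up = x2n[..., :n], x2n[..., n:]
    if activation == "silu":
        return F.silu(gate) * up
    elif activation == "gelu":
        return F.gelu(gate) * up  # exact GELU here; the kernel may use the tanh approximation
    raise ValueError(f"Unsupported activation: {activation=}")


intermediate_cache1 = torch.randn(4, 2 * 256)   # (tokens, 2N)
intermediate_cache2 = act_and_mul_reference(intermediate_cache1, activation="gelu")
print(intermediate_cache2.shape)                # torch.Size([4, 256])
```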
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,

@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )

     def forward_cuda(

@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,

@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 hidden_states=x,
                 w1=layer.w13_weight,

@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
+                activation=activation,
             )

     def forward_cpu(

@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()

@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (

@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
+            activation=self.activation,
         )

         if self.reduce_results and self.tp_size > 1:
python/sglang/srt/layers/quantization/fp8.py

@@ -763,8 +763,8 @@ class Fp8MoEMethod:
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts

@@ -785,6 +785,8 @@ class Fp8MoEMethod:
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 x,
                 layer.w13_weight,

@@ -815,6 +817,7 @@ class Fp8MoEMethod:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
python/sglang/srt/models/grok.py

@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            activation="gelu",
             use_presharded_weights=use_presharded_weights,
         )
test/srt/test_fp8_kernel.py

@@ -2,8 +2,6 @@ import unittest
 import torch

-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,