sglang · Commit 52c03f16 (unverified)
Add activation parameters to fused_moe (#3170)
Authored Jan 27, 2025 by Lianmin Zheng; committed by GitHub on Jan 27, 2025
Parent: 741fccd7
Showing 7 changed files with 52 additions and 7 deletions (+52 / -7)
python/sglang/srt/layers/moe/ep_moe/layer.py                 +3   -0
python/sglang/srt/layers/moe/fused_moe_native.py             +17  -3
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py   +18  -1
python/sglang/srt/layers/moe/fused_moe_triton/layer.py       +9   -0
python/sglang/srt/layers/quantization/fp8.py                 +4   -1
python/sglang/srt/models/grok.py                             +1   -0
test/srt/test_fp8_kernel.py                                  +0   -2
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ):
         super().__init__()
@@ -140,6 +141,7 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.activation = activation
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +168,7 @@ class EPMoE(torch.nn.Module):
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
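The EPMoE constructor now accepts and stores the activation string, but the expert-parallel forward path still only implements the SiLU gate, hence the new assert. Below is a minimal plain-Python sketch of that store-then-validate pattern; the class is a hypothetical stand-in, not the real EPMoE (which needs a distributed setup and GPU kernels).

class TinyExpertLayer:
    """Hypothetical stand-in illustrating the pattern added above."""

    def __init__(self, activation: str = "silu"):
        # Accept and remember the requested gate activation.
        self.activation = activation

    def forward(self) -> str:
        # The grouped-GEMM path only implements the SiLU gate, so reject
        # anything else up front instead of silently computing the wrong thing.
        assert self.activation == "silu", f"{self.activation=} is not supported."
        return "running silu-gated experts"


print(TinyExpertLayer().forward())   # ok
# TinyExpertLayer("gelu").forward()  # would raise AssertionError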
python/sglang/srt/layers/moe/fused_moe_native.py
@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -41,7 +42,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
@@ -84,6 +91,13 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()
+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@ def moe_forward_native(
         layer_w2_weight = layer.w2_weight[i]
         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx
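The native (non-Triton) path now dispatches on the string instead of hard-coding SiLU: it maps "silu" to SiluAndMul and "gelu" to GeluAndMul, and rejects anything else. A small plain-PyTorch sketch of the same dispatch, using F.silu/F.gelu directly so it runs anywhere; the helper name is made up for illustration.

import torch
import torch.nn.functional as F


def apply_gate_activation(gate: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    # Same three-way branch the commit adds to fused_moe_forward_native:
    # the string picks the gate nonlinearity, anything else fails loudly.
    if activation == "silu":
        return F.silu(gate)
    elif activation == "gelu":
        return F.gelu(gate)
    raise ValueError(f"Unsupported activation: {activation=}")


x1 = torch.randn(4, 8)  # gate projection output (x @ w1)
x3 = torch.randn(4, 8)  # up projection output (x @ w3)
hidden = apply_gate_activation(x1, "gelu") * x3  # gated-MLP intermediate
print(hidden.shape)  # torch.Size([4, 8])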
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -711,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -726,6 +727,7 @@ def inplace_fused_experts(
         topk_weights,
         topk_ids,
         True,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -767,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -782,6 +786,7 @@ def outplace_fused_experts(
         topk_weights,
         topk_ids,
         False,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -824,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -839,6 +846,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,
@@ -855,6 +863,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,
@@ -872,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -986,7 +996,12 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")
         invoke_fused_moe_kernel(
             intermediate_cache2,
@@ -1042,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1111,6 +1127,7 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
+        activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
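For callers, the change to the Triton path is a single new keyword threaded from fused_moe through fused_experts down to fused_experts_impl, where it selects between ops.silu_and_mul and ops.gelu_and_mul. A hedged usage sketch follows; the leading positional arguments (hidden states, w1, w2, gating output) and their shapes are not shown in this diff and are assumed from the usual Triton fused-MoE layout, and the call requires a CUDA device with the sglang kernels installed.

import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

if torch.cuda.is_available():
    E, H, INTER, T = 8, 128, 256, 16  # experts, hidden, intermediate, tokens (illustrative)
    hidden_states = torch.randn(T, H, device="cuda", dtype=torch.float16)
    w1 = torch.randn(E, 2 * INTER, H, device="cuda", dtype=torch.float16)  # fused gate+up
    w2 = torch.randn(E, H, INTER, device="cuda", dtype=torch.float16)      # down projection
    gating_output = torch.randn(T, E, device="cuda", dtype=torch.float16)  # router logits

    out = fused_moe(
        hidden_states,
        w1,
        w2,
        gating_output,
        topk=2,
        renormalize=True,
        activation="gelu",  # new in this commit; "silu" remains the default
    )
    print(out.shape)  # (T, H)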
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )

     def forward_cuda(
@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
         )

     def forward_cpu(
@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
+            activation=self.activation,
         )

         if self.reduce_results and self.tp_size > 1:
python/sglang/srt/layers/quantization/fp8.py
@@ -763,8 +763,8 @@ class Fp8MoEMethod:
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -785,6 +785,8 @@ class Fp8MoEMethod:
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 x,
                 layer.w13_weight,
@@ -815,6 +817,7 @@ class Fp8MoEMethod:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
python/sglang/srt/models/grok.py
@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            activation="gelu",
             use_presharded_weights=use_presharded_weights,
         )
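This one-line change is what motivates the plumbing above: Grok-1's experts use a GELU-gated MLP, so Grok1MoE now asks FusedMoE for activation="gelu", while every other model keeps the "silu" default. A plain-PyTorch sketch of the two gate variants; shapes and random weights here are illustrative only, not Grok's real dimensions.

import torch
import torch.nn.functional as F

x = torch.randn(2, 16)    # token activations
w1 = torch.randn(32, 16)  # gate projection
w3 = torch.randn(32, 16)  # up projection
w2 = torch.randn(16, 32)  # down projection

# SiLU gate (the previously hard-coded behaviour and still the default):
silu_out = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)
# GELU gate (what Grok-1 needs, selected via activation="gelu"):
gelu_out = F.linear(F.gelu(F.linear(x, w1)) * F.linear(x, w3), w2)

print(silu_out.shape, gelu_out.shape)  # torch.Size([2, 16]) twice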
test/srt/test_fp8_kernel.py
@@ -2,8 +2,6 @@ import unittest
 import torch
-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,