zhaoyu6/sglang · Commit 52c03f16 (Unverified)
Authored Jan 27, 2025 by Lianmin Zheng, committed by GitHub on Jan 27, 2025
Parent: 741fccd7

Add activation parameters to fused_moe (#3170)

7 changed files with 52 additions and 7 deletions (+52 / -7)
python/sglang/srt/layers/moe/ep_moe/layer.py                 +3   -0
python/sglang/srt/layers/moe/fused_moe_native.py             +17  -3
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py   +18  -1
python/sglang/srt/layers/moe/fused_moe_triton/layer.py       +9   -0
python/sglang/srt/layers/quantization/fp8.py                 +4   -1
python/sglang/srt/models/grok.py                             +1   -0
test/srt/test_fp8_kernel.py                                  +0   -2
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ):
         super().__init__()

@@ -140,6 +141,7 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()

@@ -166,6 +168,7 @@ class EPMoE(torch.nn.Module):
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"

         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
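Note: the expert-parallel path accepts the new argument but still supports only SiLU; the guard lives in forward() so an unsupported choice fails fast. A minimal sketch of that accept-but-guard pattern (hypothetical GuardedMoE class, not sglang's EPMoE):

import torch


class GuardedMoE(torch.nn.Module):
    """Illustrative only: store the requested activation, reject unsupported ones early."""

    def __init__(self, activation: str = "silu"):
        super().__init__()
        self.activation = activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The grouped-GEMM kernels behind this path only implement SiLU,
        # so anything else is rejected before any work is done.
        assert self.activation == "silu", f"{self.activation=} is not supported."
        return hidden_states  # placeholder for the real expert computation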
python/sglang/srt/layers/moe/fused_moe_native.py

@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts

@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,

@@ -41,7 +42,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(

@@ -84,6 +91,13 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()

+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):

@@ -96,7 +110,7 @@ def moe_forward_native(
         layer_w2_weight = layer.w2_weight[i]

         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx
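Note: for reference, the gated expert computation that fused_moe_forward_native performs can be written as a standalone function. This is a sketch mirroring the einsums in the diff (subscripts t = token, a = selected expert, o = intermediate dim, i = hidden dim), not the sglang code itself:

import torch
import torch.nn.functional as F


def gated_expert_forward(x, w1_weights, w3_weights, w2_weights, topk_weights, activation="silu"):
    # x: (t, i); w1/w3: (t, a, o, i); w2: (t, a, i, o); topk_weights: (t, a)
    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
    if activation == "silu":
        x1 = F.silu(x1)
    elif activation == "gelu":
        x1 = F.gelu(x1)
    else:
        raise ValueError(f"Unsupported activation: {activation=}")
    x3 = torch.einsum("ti,taoi -> tao", x, w3_weights)
    expert_outs = torch.einsum("tao,taio -> tai", x1 * x3, w2_weights)
    # weight each selected expert's output by its routing weight and sum over experts
    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))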
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -711,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -726,6 +727,7 @@ def inplace_fused_experts(
         topk_weights,
         topk_ids,
         True,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,

@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -767,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -782,6 +786,7 @@ def outplace_fused_experts(
         topk_weights,
         topk_ids,
         False,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,

@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -824,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -839,6 +846,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,

@@ -855,6 +863,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,

@@ -872,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,

@@ -986,7 +996,12 @@ def fused_experts_impl(
             block_shape=block_shape,
         )

-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")

         invoke_fused_moe_kernel(
             intermediate_cache2,

@@ -1042,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,

@@ -1111,6 +1127,7 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
+        activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
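Note: the Triton path now selects between two fused "activation and multiply" kernels. As a rough reference (a sketch, not the actual ops implementation), both take an input of width N packed as [gate | up] and write act(gate) * up into an output of width N // 2; which GELU variant (exact vs. tanh approximation) the kernel uses is not visible in this diff:

import torch
import torch.nn.functional as F


def silu_and_mul_ref(out: torch.Tensor, x: torch.Tensor) -> None:
    # x: (..., N) packed as [gate | up]; out: (..., N // 2)
    d = x.shape[-1] // 2
    out.copy_(F.silu(x[..., :d]) * x[..., d:])


def gelu_and_mul_ref(out: torch.Tensor, x: torch.Tensor) -> None:
    # Same packing as above, with a GELU gate instead of SiLU.
    d = x.shape[-1] // 2
    out.copy_(F.gelu(x[..., :d]) * x[..., d:])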
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,

@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )

     def forward_cuda(

@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,

@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 hidden_states=x,
                 w1=layer.w13_weight,

@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
+                activation=activation,
             )

     def forward_cpu(

@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()

@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (

@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
+            activation=self.activation,
         )

         if self.reduce_results and self.tp_size > 1:
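Note: the layer-level change is pure plumbing. FusedMoE stores the activation string once and forwards it to whichever quantization method performs the expert computation. A minimal sketch of that pattern with hypothetical names (only the activation keyword matches the diff):

from typing import Protocol

import torch


class MoEMethod(Protocol):
    def apply(self, x: torch.Tensor, activation: str = "silu") -> torch.Tensor: ...


class MoELayerSketch(torch.nn.Module):
    def __init__(self, quant_method: MoEMethod, activation: str = "silu"):
        super().__init__()
        self.quant_method = quant_method
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # one source of truth for the activation choice, passed down at call time
        return self.quant_method.apply(x, activation=self.activation)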
python/sglang/srt/layers/quantization/fp8.py

@@ -763,8 +763,8 @@ class Fp8MoEMethod:
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts

@@ -785,6 +785,8 @@ class Fp8MoEMethod:
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 x,
                 layer.w13_weight,

@@ -815,6 +817,7 @@ class Fp8MoEMethod:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
python/sglang/srt/models/grok.py

@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            activation="gelu",
             use_presharded_weights=use_presharded_weights,
         )
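Note: in this commit Grok-1 is the call site that overrides the default, since its MoE block gates with GELU rather than SiLU. A hedged usage sketch follows; the sizing keyword names and values (num_experts, top_k, hidden_size, intermediate_size) are assumptions for illustration, while renormalize, quant_config, tp_size, activation, and use_presharded_weights are the keywords shown in this diff:

from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

experts = FusedMoE(
    num_experts=8,            # assumed keyword and value
    top_k=2,                  # assumed keyword and value
    hidden_size=6144,         # assumed keyword and value
    intermediate_size=32768,  # assumed keyword and value
    renormalize=False,
    quant_config=None,
    tp_size=1,
    activation="gelu",        # new in this commit; selects the GELU gating path downstream
    use_presharded_weights=False,
)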
test/srt/test_fp8_kernel.py

@@ -2,8 +2,6 @@ import unittest
 import torch

-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,