sglang · Commits · 1d24db83

Unverified commit 1d24db83, authored Aug 08, 2025 by Cheng Wan, committed by GitHub on Aug 08, 2025

Expert Parallelism for GPT-OSS (#8944)

parent 44401358

Showing 8 changed files with 269 additions and 119 deletions (+269 −119)
python/sglang/srt/layers/moe/ep_moe/layer.py                  +6    −0
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py    +101  −12
python/sglang/srt/layers/moe/fused_moe_triton/layer.py        +4    −2
python/sglang/srt/layers/quantization/mxfp4.py                +80   −52
python/sglang/srt/layers/quantization/unquant.py              +9    −2
python/sglang/srt/models/gpt_oss.py                           +54   −47
python/sglang/srt/server_args.py                              +10   −4
python/sglang/srt/utils.py                                    +5    −0
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -76,6 +76,9 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -91,6 +94,9 @@ class EPMoE(FusedMoE):
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
+            with_bias=with_bias,
         )

         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
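Note: EPMoE now forwards the GPT-OSS activation parameters (activation_alpha, swiglu_limit) and with_bias to the base FusedMoE, and each EP rank owns a contiguous slice of experts starting at moe_ep_rank * num_local_experts. A minimal sketch of that partitioning arithmetic (the helper below is illustrative, not part of EPMoE):

# Hedged sketch: how an EP rank maps to its contiguous expert slice,
# following start_expert_id in the diff above.
def local_expert_range(num_experts: int, moe_ep_size: int, moe_ep_rank: int):
    assert num_experts % moe_ep_size == 0
    num_local_experts = num_experts // moe_ep_size
    start_expert_id = moe_ep_rank * num_local_experts
    end_expert_id = start_expert_id + num_local_experts - 1
    return start_expert_id, end_expert_id

# e.g. 32 experts over ep_size=4: rank 2 owns experts 16..23
print(local_expert_range(32, 4, 2))  # (16, 23)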
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -319,6 +319,7 @@ def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
     b_ptr,
+    bias_ptr,
     c_ptr,
     a_scale_ptr,
     b_scale_ptr,
@@ -340,6 +341,8 @@ def fused_moe_kernel(
     stride_be,
     stride_bk,
     stride_bn,
+    stride_bias_e,
+    stride_bias_n,
     stride_cm,
     stride_cn,
     stride_asm,
@@ -449,6 +452,10 @@ def fused_moe_kernel(
         + off_experts * stride_be
         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     )
+    if bias_ptr is not None:
+        bias = tl.load(
+            bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
+        )
    if use_int8_w8a16:
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
@@ -526,18 +533,20 @@ def fused_moe_kernel(
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
     if use_int8_w8a16:
-        accumulator = (accumulator * b_scale).to(compute_type)
+        accumulator *= b_scale
     elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k > 0 and group_n > 0:
-            accumulator = accumulator.to(compute_type)
-        else:
-            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
-    else:
-        accumulator = accumulator.to(compute_type)
+        if group_k == 0 or group_n == 0:
+            accumulator *= a_scale * b_scale
+
+    if bias_ptr is not None:
+        accumulator += bias
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator *= moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)

     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
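Note: the reworked kernel epilogue defers the cast to the output dtype until after dequantization scaling, the new optional per-expert bias, and the routed weight have been applied. The sketch below is a plain-PyTorch restatement of that ordering for clarity; the names and the eager implementation are illustrative, not the Triton kernel itself.

import torch

# Illustrative restatement of the new epilogue ordering:
# scale -> add per-expert bias -> apply routed weight -> cast once at the end.
def epilogue_reference(acc, moe_weight=None, bias=None, a_scale=None, b_scale=None,
                       out_dtype=torch.bfloat16):
    if a_scale is not None and b_scale is not None:
        acc = acc * a_scale * b_scale      # dequantize while still in fp32
    if bias is not None:
        acc = acc + bias                   # per-expert bias, broadcast over tokens
    if moe_weight is not None:
        acc = acc * moe_weight[:, None]    # router weight per token
    return acc.to(out_dtype)               # single cast at the very end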
@@ -622,6 +631,7 @@ def moe_align_block_size(
 def invoke_fused_moe_kernel(
     A: torch.Tensor,
     B: torch.Tensor,
+    bias: Optional[torch.Tensor],
     C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
@@ -711,6 +721,7 @@ def invoke_fused_moe_kernel(
     ):
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
+        assert bias is None
         fused_moe_kernel_gptq_awq[grid](
             A,
             B,
@@ -754,6 +765,7 @@ def invoke_fused_moe_kernel(
         fused_moe_kernel[grid](
             A,
             B,
+            bias,
             C,
             A_scale,
             B_scale,
@@ -770,6 +782,8 @@ def invoke_fused_moe_kernel(
             B.stride(0),
             B.stride(2),
             B.stride(1),
+            bias.stride(0) if bias is not None else 0,
+            bias.stride(1) if bias is not None else 0,
             C.stride(1),
             C.stride(2),
             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
@@ -994,6 +1008,8 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1009,6 +1025,8 @@ def inplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1016,6 +1034,8 @@ def inplace_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        b1,
+        b2,
         True,
         activation,
         apply_router_weight_on_input,
@@ -1033,6 +1053,8 @@ def inplace_fused_experts(
         block_shape,
         False,
         routed_scaling_factor,
+        activation_alpha,
+        swiglu_limit,
     )
@@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> None:
     pass
@@ -1075,6 +1101,8 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1091,6 +1119,8 @@ def outplace_fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1098,6 +1128,8 @@ def outplace_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        b1,
+        b2,
         False,
         activation,
         apply_router_weight_on_input,
@@ -1115,6 +1147,8 @@ def outplace_fused_experts(
         block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
     )
@@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
@@ -1157,6 +1195,8 @@ def fused_experts(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1174,6 +1214,8 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
     if inplace:
@@ -1184,6 +1226,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1199,6 +1243,8 @@ def fused_experts(
             a2_scale,
             block_shape,
             routed_scaling_factor,
+            activation_alpha,
+            swiglu_limit,
         )
         return hidden_states
     else:
@@ -1208,6 +1254,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1224,6 +1272,8 @@ def fused_experts(
             block_shape,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
         )
@@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
     out.mul_(routed_scaling_factor)


+@torch.compile
+def swiglu_with_alpha_and_limit(x, alpha, limit):
+    gate, up = x[..., ::2], x[..., 1::2]
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    return gate * torch.sigmoid(gate * alpha) * (up + 1)
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
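Note: swiglu_with_alpha_and_limit reads the gate/up projections from the interleaved last dimension (even indices are gates, odd indices are ups), clamps both to the configured limit, and gates with a sigmoid scaled by alpha. A minimal, hedged sanity check (the eager reference below is illustrative and not part of the commit; the limit value 7.0 is just an example, while 1.702 matches the hidden_act_alpha default seen in gpt_oss.py):

import torch

def reference(x, alpha, limit):
    # Eager restatement of swiglu_with_alpha_and_limit for a quick check.
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return gate * torch.sigmoid(gate * alpha) * (up + 1)

x = torch.randn(4, 16)                    # 8 interleaved gate/up pairs per row
out = reference(x, alpha=1.702, limit=7.0)
print(out.shape)                          # torch.Size([4, 8]) -- half the input width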
@@ -1342,6 +1402,8 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1353,7 +1415,7 @@ def fused_experts_impl(
     else:
         assert (
             hidden_states.shape[1] == w1.shape[2] - padded_size
-        ), "Hidden size mismatch"
+        ), f"Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -1449,6 +1511,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             curr_hidden_states,
             w1,
+            b1,
             intermediate_cache1,
             a1_scale,
             w1_scale,
@@ -1470,13 +1533,24 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
         if activation == "silu":
-            if _is_cuda:
+            if activation_alpha is not None:
+                assert swiglu_limit is not None
+                intermediate_cache2 = swiglu_with_alpha_and_limit(
+                    intermediate_cache1.view(-1, N),
+                    activation_alpha,
+                    swiglu_limit,
+                )
+            elif _is_cuda:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
                     intermediate_cache2, intermediate_cache1.view(-1, N)
                 )
         elif activation == "gelu":
+            assert (
+                activation_alpha is None
+            ), "activation_alpha is not supported for gelu"
+            assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
             if _is_cuda:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
@@ -1489,6 +1563,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             intermediate_cache2,
             w2,
+            b2,
             (
                 intermediate_cache3
                 if not no_combine and topk_ids.shape[1] != 1
@@ -1567,6 +1642,8 @@ def fused_moe(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1584,6 +1661,8 @@ def fused_moe(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1594,6 +1673,8 @@ def fused_moe(
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
     - topk_output (TopKOutput): The top-k output of the experts.
+    - b1 (Optional[torch.Tensor]): Optional bias for w1.
+    - b2 (Optional[torch.Tensor]): Optional bias for w2.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
@@ -1615,6 +1696,10 @@ def fused_moe(
         a2.
     - block_shape: (Optional[List[int]]): Optional block size for block-wise
         quantization.
+    - activation_alpha (Optional[float]): Optional alpha for the activation
+        function.
+    - swiglu_limit (Optional[float]): Optional limit for the swiglu activation
+        function.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -1625,6 +1710,8 @@ def fused_moe(
         w1,
         w2,
         topk_output,
+        b1=b1,
+        b2=b2,
         inplace=inplace,
         activation=activation,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -1642,4 +1729,6 @@ def fused_moe(
         block_shape=block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
     )
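Note: with these additions, a caller can hand fused_moe per-expert biases plus the clamped-SwiGLU parameters alongside the usual weights. The sketch below only illustrates the tensor shapes implied by the mxfp4 create_weights code further down (E experts, hidden size H, intermediate size N); the concrete sizes, the 3-tuple used as topk_output, and the commented call are assumptions for illustration, not taken from the commit.

import torch

E, H, N, T, top_k = 8, 128, 256, 4, 2
w1 = torch.randn(E, 2 * N, H, dtype=torch.bfloat16)   # fused gate_up_proj
b1 = torch.randn(E, 2 * N, dtype=torch.bfloat16)
w2 = torch.randn(E, H, N, dtype=torch.bfloat16)       # down_proj
b2 = torch.randn(E, H, dtype=torch.bfloat16)
x = torch.randn(T, H, dtype=torch.bfloat16)

topk_weights = torch.rand(T, top_k)
topk_ids = torch.randint(0, E, (T, top_k), dtype=torch.int32)

# Hypothetical invocation (requires a CUDA device and sglang's TopKOutput):
# out = fused_moe(
#     x, w1, w2, (topk_weights, topk_ids, None),
#     b1=b1, b2=b2,
#     activation="silu", activation_alpha=1.702, swiglu_limit=7.0,
# )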
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -199,7 +199,7 @@ class FusedMoE(torch.nn.Module):
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels, with_bias=with_bias
+                self.use_triton_kernels
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
@@ -809,7 +809,9 @@ class FusedMoE(torch.nn.Module):
             # If we are in EP mode, we need to move the expert map to GPU.
             self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

-        if self.expert_map_gpu is not None:
+        if self.expert_map_gpu is not None and isinstance(
+            topk_output, StandardTopKOutput
+        ):
             topk_output = topk_output._replace(
                 topk_ids=self.expert_map_gpu[topk_output.topk_ids]
             )
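Note: in EP mode the router still emits global expert ids, and expert_map_gpu translates them into local ids on the owning rank; the added guard applies the remap only when topk_output is a StandardTopKOutput. A minimal sketch of the remapping idea (the -1 sentinel for experts owned by other ranks is an assumption for illustration):

import torch

# Hedged sketch of global -> local expert-id remapping.
num_experts, ep_size, ep_rank = 8, 2, 1
num_local = num_experts // ep_size

expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
local_ids = torch.arange(ep_rank * num_local, (ep_rank + 1) * num_local)
expert_map[local_ids] = torch.arange(num_local, dtype=torch.int32)

topk_ids = torch.tensor([[1, 6], [4, 7]])   # global ids from the router
print(expert_map[topk_ids])                 # tensor([[-1,  2], [ 0,  3]])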
python/sglang/srt/layers/quantization/mxfp4.py

@@ -8,6 +8,7 @@ import logging
 from typing import TYPE_CHECKING, List, Optional

 import torch
+import triton.language as tl
 from torch.nn.parameter import Parameter

 from sglang.srt.layers.quantization.base_config import (
@@ -24,6 +25,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_triton_kernels_available,
     log_info_on_rank0,
     next_power_of_2,
     round_up,
@@ -31,7 +33,7 @@ from sglang.srt.utils import (
 )

 _is_sm100_supported = is_cuda() and is_sm100_supported()
-has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+has_triton_kernels = is_triton_kernels_available()

 if is_flashinfer_available():
@@ -188,12 +190,7 @@ class Mxfp4Config(QuantizationConfig):
         ):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            use_flashinfer = global_server_args_dict.get(
-                "enable_flashinfer_mxfp4_moe", False
-            )
-            return Mxfp4MoEMethod(
-                use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
-            )
+            return Mxfp4MoEMethod(prefix)
         else:
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
@@ -206,15 +203,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(
         self,
-        use_triton_kernels: bool = True,
-        with_bias: bool = True,
-        use_flashinfer: bool = False,
+        prefix: str,
     ):
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
         super().__init__()

         self.topk_indices_dtype = None
-        self.use_triton_kernels = use_triton_kernels
-        self.with_bias = with_bias
-        self.use_flashinfer = use_flashinfer
+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+        self.with_bias = False
+        self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]

         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -236,12 +234,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
-        # print(f"hi {self=} create_weights {layer=}")
         self.num_experts = num_experts
         weight_dtype = torch.uint8
         scale_dtype = torch.uint8
+        self.with_bias = with_bias
         mxfp4_block = 32

         # pad the intermediate size to be a multiple of 2 * mxfp4_block
@@ -264,7 +263,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 hidden_size // 2,
                 dtype=weight_dtype,
@@ -276,7 +275,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 hidden_size // mxfp4_block,
                 dtype=scale_dtype,
@@ -288,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w13_weight_bias = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 dtype=torch.bfloat16,
             ),
@@ -300,7 +299,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 intermediate_size_per_partition_after_pad // 2,
                 dtype=weight_dtype,
@@ -312,7 +311,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w2_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 intermediate_size_per_partition_after_pad // mxfp4_block,
                 dtype=scale_dtype,
@@ -323,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         w2_weight_bias = torch.nn.Parameter(
-            torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
+            torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight_bias", w2_weight_bias)
@@ -484,38 +483,51 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )
             return

-        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
-
-        w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
-        w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
-
-        layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
-        layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
-
-        num_warps = 8
-
-        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
-            layer.w13_weight, layer.w13_weight_scale, num_warps
-        )
-        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
-            layer.w2_weight, layer.w2_weight_scale, num_warps
-        )
-
-        self.w13_precision_config = PrecisionConfig(
-            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
-        )
-        self.w2_precision_config = PrecisionConfig(
-            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
-        )
-
-        self.w13_weight_triton_tensor = w13_weight
-        self.w2_weight_triton_tensor = w2_weight
-
-        # need to delete the original weights to save memory on single GPU
-        del layer.w13_weight
-        del layer.w2_weight
-        layer.w13_weight = None
-        layer.w2_weight = None
+        if self.use_triton_kernels:
+            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+
+            w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
+            w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
+
+            layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
+            layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
+
+            num_warps = 8
+
+            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
+                layer.w13_weight, layer.w13_weight_scale, num_warps
+            )
+            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
+                layer.w2_weight, layer.w2_weight_scale, num_warps
+            )
+
+            self.w13_precision_config = PrecisionConfig(
+                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
+            )
+            self.w2_precision_config = PrecisionConfig(
+                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
+            )
+
+            self.w13_weight_triton_tensor = w13_weight
+            self.w2_weight_triton_tensor = w2_weight
+
+            del layer.w13_weight
+            del layer.w2_weight
+        else:
+            from triton_kernels.numerics_details.mxfp import upcast_from_mxfp
+
+            w13_weight = upcast_from_mxfp(
+                layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
+            )
+            w2_weight = upcast_from_mxfp(
+                layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
+            )
+
+            del layer.w13_weight
+            del layer.w2_weight
+            del layer.w13_weight_scale
+            del layer.w2_weight_scale
+
+            layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)

         torch.cuda.empty_cache()

     def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
@@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             None,  # output1_scale_scalar
             None,  # output1_scale_gate_scalar
             None,  # output2_scale_scalar
-            self.num_experts,
+            layer.num_experts,
             top_k,
             None,  # n_group
             None,  # topk_group
             self.intermediate_size,  # padded to multiple of 256
-            0,  # local_expert_offset
-            self.num_experts,  # local num experts
+            layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
+            layer.num_local_experts,  # local num experts
             None,
             self._get_tile_tokens_dim(x, top_k),
             1,  # routing_method_type, renormalize
@@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             return trtllm_gen_output

         if self.use_triton_kernels:
+            assert (
+                layer.moe_ep_size == 1
+            ), "Expert parallel is not supported when using triton kernels"
             if self.with_bias:
-                # TODO why we do not put weights on layer?
-                assert layer.w13_weight is None
-                assert layer.w2_weight is None
                 return self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=self.w13_weight_triton_tensor,
@@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 topk_output=topk_output,
             )
         else:
-            raise NotImplementedError()
+            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_output=topk_output,
+                b1=layer.w13_weight_bias,
+                b2=layer.w2_weight_bias,
+                inplace=inplace,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                no_combine=no_combine,
+                routed_scaling_factor=routed_scaling_factor,
+                activation_alpha=activation_alpha,
+                swiglu_limit=swiglu_limit,
+            )
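Note: the mxfp4 parameter shapes above encode the packing directly: two 4-bit values share one uint8 byte (hence hidden_size // 2 on the packed axis) and one uint8 scale covers a 32-element block (hence // mxfp4_block), and with EP the leading expert dimension is now layer.num_local_experts rather than the global expert count. A small shape sketch under those assumptions (the concrete sizes are illustrative and the per-partition padding is ignored here):

# Shape sketch for the mxfp4 MoE parameters (illustrative sizes only).
mxfp4_block = 32
num_global_experts, ep_size = 32, 4
num_local_experts = num_global_experts // ep_size       # 8 experts per EP rank
hidden_size, intermediate = 2880, 2880

w13_weight_shape = (num_local_experts, 2 * intermediate, hidden_size // 2)           # packed fp4 in uint8
w13_scale_shape = (num_local_experts, 2 * intermediate, hidden_size // mxfp4_block)  # one scale per 32 values
w2_weight_shape = (num_local_experts, hidden_size, intermediate // 2)
w2_scale_shape = (num_local_experts, hidden_size, intermediate // mxfp4_block)
print(w13_weight_shape, w13_scale_shape)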
python/sglang/srt/layers/quantization/unquant.py

@@ -126,10 +126,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

-    def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
+    def __init__(self, use_triton_kernels: bool = False):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
-        self.with_bias = with_bias
+        self.with_bias = False

         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -151,8 +151,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
+        self.with_bias = with_bias
+
         # Fused gate_up_proj (column parallel)
         w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
         if self.use_triton_kernels:
@@ -319,12 +322,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
+            b1=getattr(layer, "w13_weight_bias", None),
+            b2=getattr(layer, "w2_weight_bias", None),
             topk_output=topk_output,
             inplace=inplace and not no_combine,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
         )

     def forward_cpu(
python/sglang/srt/models/gpt_oss.py

@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
     get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -96,11 +97,6 @@ class GptOssSparseMoeBlock(nn.Module):
         self.activation = config.hidden_act
         self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
         self.swiglu_limit = config.swiglu_limit
-        if self.tp_size > config.num_local_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_local_experts}."
-            )

         if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
             self.topk = None
@@ -708,22 +704,26 @@ class GptOssForCausalLM(nn.Module):
         loaded_params: set[str] = set()
         mxfp4_block = 32

-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
+        moe_tp_rank = get_moe_tensor_parallel_rank()
+        moe_tp_size = get_moe_tensor_parallel_world_size()
+        moe_ep_rank = get_moe_expert_parallel_rank()
+        moe_ep_size = get_moe_expert_parallel_world_size()

         intermediate_size = self.config.intermediate_size
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = intermediate_size_block // tp_size
+        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

         # Calculate common slicing bounds for current rank
-        tp_rank_start = tp_rank * per_rank_intermediate_size
-        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
-
-        # Attention heads per rank
-        heads_per_rank = self.config.num_attention_heads // tp_size
-        head_start = tp_rank * heads_per_rank
-        num_experts = self.config.num_local_experts
+        assert self.config.num_local_experts % moe_ep_size == 0
+        moe_num_global_experts = self.config.num_local_experts
+        moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+
+        moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
+        moe_tp_rank_end = min(
+            (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
+        )
+
+        moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
+        moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts

         for name, weight in weights:
             weight = weight.cuda()
@@ -735,10 +735,14 @@ class GptOssForCausalLM(nn.Module):
                 # flat weight from (E, 2 * N, block_size, entry_per_block)
                 # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                 weight = weight.view(
-                    num_experts, 2 * intermediate_size, -1
+                    moe_num_global_experts, 2 * intermediate_size, -1
                 ).contiguous()

-                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                    ...,
+                ]

                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -757,9 +761,13 @@ class GptOssForCausalLM(nn.Module):
                 # same flatten here, but since 2 mx4 value are packed in 1
                 # uint8, divide by 2
                 weight = weight.view(
-                    num_experts, -1, intermediate_size // 2
+                    moe_num_global_experts, -1, intermediate_size // 2
                 ).contiguous()
-                narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    ...,
+                    moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
+                ]
                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -775,7 +783,11 @@ class GptOssForCausalLM(nn.Module):
             elif "gate_up_proj_scales" in name:
                 # Handle MLP gate and up projection weights scale
                 new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
-                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                    ...,
+                ]
                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -792,7 +804,9 @@ class GptOssForCausalLM(nn.Module):
                 # Handle MLP down projection weights
                 new_name = name.replace("down_proj_scales", "w2_weight_scale")
                 narrow_weight = weight[
-                    ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    ...,
+                    moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
                 ]
                 param = params_dict[new_name]
@@ -809,7 +823,10 @@ class GptOssForCausalLM(nn.Module):
                 # Handle MLP gate and up projection biases
                 new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
-                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                ]
                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -823,15 +840,20 @@ class GptOssForCausalLM(nn.Module):
                 loaded_params.add(new_name)
             elif "down_proj_bias" in name:
-                if get_moe_tensor_parallel_rank() != 0:
-                    weight = torch.zeros_like(weight)
+                narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
+                if moe_tp_rank != 0:
+                    narrow_weight = torch.zeros_like(narrow_weight)

                 # Handle MLP down projection bias
                 new_name = name.replace("down_proj_bias", "w2_weight_bias")
                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(
-                    param, weight, weight_name=new_name, shard_id=None, expert_id=None
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
                 )
                 loaded_params.add(new_name)
@@ -910,27 +932,12 @@ class GptOssForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]

-        if self.quant_config is not None and (
-            self.quant_config.get_name() == "mxfp4"
-        ):
-            expert_params_mapping = (
-                get_moe_impl_class().make_expert_params_mapping_fused_mxfp4(
-                    ckpt_gate_up_proj_name="gate_up_proj_blocks",
-                    ckpt_down_proj_name="down_proj_blocks",
-                    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
-                    ckpt_down_proj_bias_name="down_proj_bias",
-                    ckpt_gate_up_proj_scale_name="gate_up_proj_scales",
-                    ckpt_down_proj_scale_name="down_proj_scales",
-                )
-            )
-        else:
-            expert_params_mapping = (
-                get_moe_impl_class().make_expert_params_mapping_fused(
-                    ckpt_gate_up_proj_name="gate_up_proj",
-                    ckpt_down_proj_name="down_proj",
-                    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
-                    ckpt_down_proj_bias_name="down_proj_bias",
-                )
-            )
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
+            ckpt_gate_up_proj_name="gate_up_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
+            ckpt_down_proj_bias_name="down_proj_bias",
+        )

         params_dict = dict(self.named_parameters())
         params_checker = {k: False for k, v in params_dict.items()}
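Note: under the new loading scheme each rank keeps only its slice of the checkpoint MoE tensors: the expert dimension is cut by the EP rank and the intermediate dimension by the MoE-TP rank (halved, or divided by the 32-wide mxfp4 block, for packed weights and scales). A small hedged sketch of the two-axis slicing (sizes are illustrative; the real code also handles mxfp4 packing and scales):

import torch

E, N, H = 32, 288, 256              # global experts, intermediate, hidden (illustrative)
moe_ep_rank, moe_ep_size = 1, 4
moe_tp_rank, moe_tp_size = 0, 2

num_local_experts = E // moe_ep_size
ep_start, ep_end = moe_ep_rank * num_local_experts, (moe_ep_rank + 1) * num_local_experts
per_rank_n = N // moe_tp_size
tp_start, tp_end = moe_tp_rank * per_rank_n, (moe_tp_rank + 1) * per_rank_n

weight = torch.empty(E, 2 * N, H)   # checkpoint layout: (experts, 2 * intermediate, hidden)
narrow = weight[ep_start:ep_end, 2 * tp_start : 2 * tp_end, ...]
print(narrow.shape)                 # torch.Size([8, 288, 256])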
python/sglang/srt/server_args.py

@@ -37,6 +37,7 @@ from sglang.srt.utils import (
     is_hip,
     is_port_available,
     is_remote_url,
+    is_triton_kernels_available,
     is_valid_ipv6_address,
     nullable_str,
 )
@@ -492,10 +493,15 @@ class ServerArgs:
                     "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
                 )
             else:
-                self.enable_triton_kernel_moe = True
-                logger.info(
-                    "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
-                )
+                if self.enable_triton_kernel_moe:
+                    assert (
+                        self.ep_size == 1
+                    ), "Triton kernel MoE is only supported when ep_size == 1"
+                if not self.enable_triton_kernel_moe and self.ep_size == 1:
+                    self.enable_triton_kernel_moe = True
+                    logger.info(
+                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+                    )

             self.disable_hybrid_swa_memory = True
python/sglang/srt/utils.py

@@ -2961,3 +2961,8 @@ class ConcurrentCounter:
         other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
         """
         self.wait_for(lambda count: count == 0)
+
+
+@lru_cache(maxsize=1)
+def is_triton_kernels_available() -> bool:
+    return importlib.util.find_spec("triton_kernels") is not None
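Note: the helper caches a single importlib probe so repeated feature checks stay cheap; mxfp4.py and server_args.py above use it to gate the triton_kernels code paths. A minimal hedged usage sketch (the fallback assignment is illustrative, not from the commit):

from sglang.srt.utils import is_triton_kernels_available

# Guard an optional dependency without importing it eagerly.
if is_triton_kernels_available():
    from triton_kernels.matmul_ogs import PrecisionConfig  # only if the package exists
else:
    PrecisionConfig = None  # callers would then take the plain fused_experts path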