change / sglang / Commits / 29589512

Commit 29589512 (unverified), authored Aug 14, 2025 by Cheng Wan, committed by GitHub on Aug 14, 2025

[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

Parent: 584e1ab2
Changes: 69 files in the commit; this page shows 20 changed files with 172 additions and 303 deletions (+172 -303).
python/sglang/srt/layers/quantization/base_config.py                                 +2  -6
python/sglang/srt/layers/quantization/blockwise_int8.py                              +4  -12
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py   +7  -14
python/sglang/srt/layers/quantization/fp4.py                                         +13 -30
python/sglang/srt/layers/quantization/fp8.py                                         +24 -24
python/sglang/srt/layers/quantization/fp8_utils.py                                   +1  -0
python/sglang/srt/layers/quantization/gptq.py                                        +5  -4
python/sglang/srt/layers/quantization/marlin_utils.py                                +4  -3
python/sglang/srt/layers/quantization/modelopt_quant.py                              +23 -34
python/sglang/srt/layers/quantization/moe_wna16.py                                   +10 -15
python/sglang/srt/layers/quantization/mxfp4.py                                       +9  -25
python/sglang/srt/layers/quantization/unquant.py                                     +27 -69
python/sglang/srt/layers/quantization/w4afp8.py                                      +7  -8
python/sglang/srt/layers/quantization/w8a8_fp8.py                                    +5  -13
python/sglang/srt/layers/quantization/w8a8_int8.py                                   +5  -13
python/sglang/srt/managers/schedule_batch.py                                         +1  -9
python/sglang/srt/managers/scheduler.py                                              +11 -14
python/sglang/srt/model_executor/model_runner.py                                     +0  -3
python/sglang/srt/models/dbrx.py                                                     +12 -6
python/sglang/srt/models/deepseek.py                                                 +2  -1
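The common thread in the diffs below is that the per-call keyword arguments previously threaded through FusedMoEMethodBase.apply (activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor) are collapsed into a single moe_runner_config object that is built once per layer. A minimal sketch of such a config, assuming a plain dataclass whose fields mirror the removed keyword arguments (the real MoeRunnerConfig in sglang.srt.layers.moe.moe_runner may carry additional fields):

# Hypothetical sketch of a MoE runner config; field names mirror the keyword
# arguments removed in this commit, not necessarily the real class definition.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfigSketch:
    activation: str = "silu"                       # "silu" or "gelu"
    apply_router_weight_on_input: bool = False     # scale inputs by router weights
    inplace: bool = True                           # let fused kernels write in place
    no_combine: bool = False                       # skip the final expert combine
    routed_scaling_factor: Optional[float] = None  # extra scaling of the routed output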
python/sglang/srt/layers/quantization/base_config.py

@@ -9,6 +9,7 @@ import torch
 from torch import nn
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         raise NotImplementedError
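With the new base signature, method implementations read per-call behaviour from the config instead of from keyword arguments. A self-contained toy sketch of that pattern (the function and config object here are illustrative stand-ins, not sglang classes):

# Self-contained sketch of the new calling convention: options travel in one
# config object. SimpleNamespace stands in for MoeRunnerConfig here.
from types import SimpleNamespace

import torch


def toy_apply(x: torch.Tensor, moe_runner_config) -> torch.Tensor:
    # Methods now read per-call behaviour from the config, e.g. the activation
    # check that several methods in this commit rewrite this way.
    assert moe_runner_config.activation == "silu", "Only SiLU activation is supported."
    out = torch.nn.functional.silu(x)
    if moe_runner_config.routed_scaling_factor is not None:
        out = out * moe_runner_config.routed_scaling_factor
    return out


cfg = SimpleNamespace(activation="silu", routed_scaling_factor=2.0)
print(toy_apply(torch.randn(2, 4), cfg).shape)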
python/sglang/srt/layers/quantization/blockwise_int8.py

@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module

@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale_inv),
             w2_scale=(layer.w2_weight_scale_inv),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -23,6 +23,7 @@ from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,

@@ -269,12 +270,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts

@@ -283,8 +279,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=self.weight_quant.strategy == QuantizationStrategy.CHANNEL,

@@ -292,8 +287,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            routed_scaling_factor=routed_scaling_factor,
         )

@@ -601,12 +594,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         topk_weights, topk_ids, router_logits = topk_output
python/sglang/srt/layers/quantization/fp4.py

@@ -41,6 +41,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 logger = logging.getLogger(__name__)

@@ -220,22 +221,10 @@ class MxFp4LinearMethod(LinearMethodBase):
         return out
 
 
-class MxFp4MoEMethod:
-    def __new__(cls, *args, **kwargs):
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+class MxFp4MoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: Mxfp4Config):
         self.quant_config = quant_config
 
     @staticmethod
     def get_moe_method(

@@ -364,12 +353,7 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output

@@ -383,7 +367,9 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             activation=(
-                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )

@@ -497,12 +483,7 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output

@@ -516,7 +497,9 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             activation=(
-                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )
python/sglang/srt/layers/quantization/fp8.py

@@ -79,6 +79,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config

@@ -982,12 +983,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -996,7 +992,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
             return torch.ops.sgl_kernel.fused_experts_cpu(

@@ -1021,8 +1017,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer,
                 x,
                 topk_output,
-                activation,
-                no_combine,
+                moe_runner_config.activation,
+                moe_runner_config.no_combine,
             )
             if ret is not None:
                 return ret

@@ -1060,8 +1056,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 use_fp8_blockscale=True,
             )
             # TODO: Fuse into select_experts
-            if routed_scaling_factor is not None:
-                output *= routed_scaling_factor
+            if moe_runner_config.routed_scaling_factor is not None:
+                output *= moe_runner_config.routed_scaling_factor
             return output
 
         # Expert fusion with FP8 quantization
         return fused_experts(

@@ -1069,9 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace and not no_combine,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv

@@ -1084,26 +1078,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        *,
-        activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+        activation = moe_runner_config.activation
+        routed_scaling_factor = moe_runner_config.routed_scaling_factor
+
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+        assert TopKOutputChecker.format_is_bypassed(topk_output)
+        router_logits = topk_output.router_logits
+        topk_config = topk_output.topk_config
+
         assert (
             activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),

@@ -1115,9 +1115,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             gemm2_weights=layer.w2_weight,
             gemm2_weights_scale=layer.w2_weight_scale_inv,
             num_experts=layer.num_experts,
-            top_k=layer.top_k,
-            n_group=layer.num_expert_group,
-            topk_group=layer.topk_group,
+            top_k=topk_config.top_k,
+            n_group=topk_config.num_expert_group,
+            topk_group=topk_config.topk_group,
             intermediate_size=layer.w2_weight.shape[2],
             local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
             local_num_experts=layer.num_local_experts,
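apply_with_router_logits now takes the TopKOutput (in its "bypassed" format) instead of raw router logits, and pulls the routing parameters out of its topk_config rather than off the layer. A small self-contained sketch of that unpacking pattern, using stand-in types rather than the real sglang classes:

# Stand-in types to illustrate the "bypassed" TopKOutput unpacking; the real
# classes live in sglang.srt.layers.moe.topk and carry more information.
from dataclasses import dataclass

import torch


@dataclass
class _TopKConfig:
    top_k: int
    num_expert_group: int
    topk_group: int


@dataclass
class _BypassedTopKOutput:
    router_logits: torch.Tensor  # routing is handled later, inside the kernel
    topk_config: _TopKConfig


def unpack_bypassed(topk_output: _BypassedTopKOutput):
    # Mirrors the new code path: logits plus routing hyperparameters, instead of
    # reading top_k / num_expert_group / topk_group attributes off the layer.
    return topk_output.router_logits, topk_output.topk_config


logits = torch.randn(8, 64)
out = _BypassedTopKOutput(logits, _TopKConfig(top_k=8, num_expert_group=8, topk_group=4))
router_logits, topk_config = unpack_bypassed(out)
print(router_logits.shape, topk_config.top_k)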
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -113,6 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
     return weight, weight_scale, input_scale
 
 
+# TODO(ch-wan): define these backends in --moe-runner-backend
 def cutlass_block_fp8_supported() -> bool:
     if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
python/sglang/srt/layers/quantization/gptq.py

@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 from sglang.srt.utils import is_cuda

@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # Delay the import to avoid circular dependency
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # The input must currently be float16
         orig_dtype = x.dtype
python/sglang/srt/layers/quantization/marlin_utils.py

@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
 try:
     from vllm import _custom_ops as ops

@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
     )[0]
 
 
-def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
     hidden_size = layer.hidden_size
     intermediate_size_per_partition = layer.intermediate_size_per_partition
     # apply_router_weight_on_input is not supported for moe marlin
-    supports_router_weight = not layer.apply_router_weight_on_input
+    supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
     # moe marlin requires the activation to be silu
-    supports_activation = layer.activation == "silu"
+    supports_activation = layer.moe_runner_config.activation == "silu"
 
     # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
     # down: (n, k) = (hidden_size, intermediate_size_per_partition)
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():

@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=False,  # ModelOpt uses per-tensor quantization
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
         )

@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
         """Access the global enable_flashinfer_cutlass_moe setting."""
-        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
+        return get_moe_runner_backend().is_flashinfer_cutlass()
 
     def create_weights(
         self,

@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        ep_rank: Optional[int] = None,
-        ep_size: Optional[int] = None,
-        tp_rank: Optional[int] = None,
-        tp_size: Optional[int] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):

@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe:
             assert (
-                not apply_router_weight_on_input
+                not moe_runner_config.apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint

@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                     layer.w2_blockscale_swizzled.view(torch.int32),
                     layer.g2_alphas,
                 ],
-                ep_size=ep_size,
-                ep_rank=ep_rank,
-                tp_size=tp_size,
-                tp_rank=tp_rank,
+                ep_size=layer.moe_ep_size,
+                ep_rank=layer.moe_ep_rank,
+                tp_size=layer.moe_tp_size,
+                tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
 
-            if routed_scaling_factor is not None:
-                output *= routed_scaling_factor
+            if moe_runner_config.routed_scaling_factor is not None:
+                output *= moe_runner_config.routed_scaling_factor
 
             return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             params=layer.cutlass_moe_params,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
 
-        if routed_scaling_factor is not None:
-            output *= routed_scaling_factor
+        if moe_runner_config.routed_scaling_factor is not None:
+            output *= moe_runner_config.routed_scaling_factor
 
         return output
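Several backend toggles that used to be read from global_server_args_dict (enable_flashinfer_cutlass_moe, enable_triton_kernel_moe, enable_flashinfer_mxfp4_moe) are now queried from a MoE-runner-backend accessor. A rough self-contained sketch of that accessor pattern, with an invented enum standing in for whatever get_moe_runner_backend() actually returns:

# Illustrative only: a tiny backend enum plus module-level accessor that mimics
# the get_moe_runner_backend() query style used in this commit. The real
# implementation lives in sglang.srt.layers.moe and is configured from server args.
import enum


class _MoeRunnerBackend(enum.Enum):
    AUTO = "auto"
    TRITON_KERNEL = "triton_kernel"
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"

    def is_triton_kernel(self) -> bool:
        return self is _MoeRunnerBackend.TRITON_KERNEL

    def is_flashinfer_cutlass(self) -> bool:
        return self is _MoeRunnerBackend.FLASHINFER_CUTLASS

    def is_flashinfer_mxfp4(self) -> bool:
        return self is _MoeRunnerBackend.FLASHINFER_MXFP4


_BACKEND = _MoeRunnerBackend.AUTO


def _get_moe_runner_backend() -> _MoeRunnerBackend:
    # The real accessor reads state initialized by initialize_moe_config(server_args).
    return _BACKEND


print(_get_moe_runner_backend().is_flashinfer_cutlass())  # False with the AUTO default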
python/sglang/srt/layers/quantization/moe_wna16.py

@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp

@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.w13_qweight,
             layer.w2_qweight,
             topk_output=topk_output,
-            inplace=inplace,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,

@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
             w1_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     @staticmethod

@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
             )
             if "w13_qzeros" in weight_name:
-                tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
-                    tp_rank
-                ]
+                tensor = loaded_weight.view(
+                    layer.moe_tp_size, -1, loaded_weight.size(1)
+                )[tp_rank]
                 if shard_id == "w1":
                     param.data[expert_id, : shard_size // 2] = tensor
                 else:
                     param.data[expert_id, shard_size // 2 :] = tensor
             elif "w2_qzeros" in weight_name:
                 param.data[expert_id] = loaded_weight.view(
-                    loaded_weight.size(0), layer.tp_size, -1
+                    loaded_weight.size(0), layer.moe_tp_size, -1
                 )[:, tp_rank]
             else:
                 weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
python/sglang/srt/layers/quantization/mxfp4.py

@@ -16,14 +16,13 @@
 from __future__ import annotations
 
 import importlib.util
 import logging
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import triton.language as tl
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,

@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.layers.utils import is_sm100_supported
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,

@@ -60,6 +58,7 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 OCP_MX_BLOCK_SIZE = 32

@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self,
         prefix: str,
     ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         super().__init__()
 
         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
-        self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
+        self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None

@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             logger,
             f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
         )
+        # TODO: these values are hardcoded for now, we need to get them from the model
         layer.gemm1_alpha = Parameter(
             torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
             requires_grad=False,

@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance

@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     b1=layer.w13_weight_bias,
                     b2=layer.w2_weight_bias,
                     topk_output=topk_output,
-                    activation=activation,
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                 )
             else:
                 return self.triton_kernel_moe_forward(

@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
+                    moe_runner_config=moe_runner_config,
                 )
         else:
             from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
                 b1=layer.w13_weight_bias,
                 b2=layer.w2_weight_bias,
-                inplace=inplace,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
-                routed_scaling_factor=routed_scaling_factor,
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
             )
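Note that Mxfp4MoEMethod.apply previously accepted kernel-specific extras (activation_alpha, swiglu_limit) as keyword arguments; after this change only moe_runner_config is passed through, so those values presumably travel inside the config as well. A hedged sketch of forwarding optional per-kernel options from a config object without assuming the exact field names:

# Illustrative helper: pick up optional kernel extras from a config object if
# they exist, falling back to None. Field names here are guesses for the sketch,
# not the authoritative MoeRunnerConfig schema.
from types import SimpleNamespace
from typing import Any, Dict


def collect_kernel_extras(moe_runner_config: Any) -> Dict[str, Any]:
    extras = {}
    for field in ("activation_alpha", "swiglu_limit"):
        value = getattr(moe_runner_config, field, None)
        if value is not None:
            extras[field] = value
    return extras


cfg = SimpleNamespace(activation="silu", activation_alpha=1.702, swiglu_limit=7.0)
print(collect_kernel_extras(cfg))  # {'activation_alpha': 1.702, 'swiglu_limit': 7.0}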
python/sglang/srt/layers/quantization/unquant.py

 from __future__ import annotations
 
-import importlib
-from typing import TYPE_CHECKING, Callable, List, Optional
+import importlib.util
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.nn.functional as F

@@ -24,7 +24,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

@@ -221,31 +221,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        kwargs = {}
-        if activation_alpha is not None:
-            kwargs["activation_alpha"] = activation_alpha
-        if swiglu_limit is not None:
-            kwargs["swiglu_limit"] = swiglu_limit
-
         return self.forward(
             x=x,
             layer=layer,
             topk_output=topk_output,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            **kwargs,
+            moe_runner_config=moe_runner_config,
         )
 
     def forward_cuda(

@@ -253,18 +236,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         if self.use_triton_kernels:
             if self.with_bias:
                 assert self.triton_kernel_moe_with_bias_forward is not None
                 return self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,

@@ -272,24 +249,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     b1=layer.w13_weight_bias,
                     b2=layer.w2_weight_bias,
                     topk_output=topk_output,
-                    activation=activation,
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                     w1_pcg=None,
                     w2_pcg=None,
                 )
             else:
                 assert self.triton_kernel_moe_forward is not None
                 return self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
+                    moe_runner_config=moe_runner_config,
                 )
         else:
             if _use_aiter:
-                assert not no_combine, "unsupported"
+                assert not moe_runner_config.no_combine, "unsupported"
                 topk_weights, topk_ids, _ = topk_output
-                if apply_router_weight_on_input:
+                if moe_runner_config.apply_router_weight_on_input:
                     assert (
                         topk_weights.dim() == 2
                     ), "`topk_weights` should be in shape (num_tokens, topk)"

@@ -309,7 +286,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     topk_ids,
                     activation=(
                         ActivationType.Silu
-                        if activation == "silu"
+                        if moe_runner_config.activation == "silu"
                         else ActivationType.Gelu
                     ),
                 )

@@ -325,13 +302,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     b1=getattr(layer, "w13_weight_bias", None),
                     b2=getattr(layer, "w2_weight_bias", None),
                     topk_output=topk_output,
-                    inplace=inplace and not no_combine,
-                    activation=activation,
-                    apply_router_weight_on_input=apply_router_weight_on_input,
-                    no_combine=no_combine,
-                    routed_scaling_factor=routed_scaling_factor,
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                 )
 
     def forward_cpu(

@@ -339,21 +310,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert activation == "silu", f"activation = {activation} is not supported."
-
-        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), f"activation = {moe_runner_config.activation} is not supported."
+
+        if (
+            use_intel_amx_backend(layer)
+            and not moe_runner_config.apply_router_weight_on_input
+        ):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
 
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,

@@ -378,11 +349,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer,
             x,
             topk_output,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
+            moe_runner_config,
         )
 
     def forward_npu(

@@ -390,12 +357,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

@@ -403,11 +365,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer,
             x,
             topk_output,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
+            moe_runner_config,
         )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
python/sglang/srt/layers/quantization/w4afp8.py

@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
+    from sglang.srt.layers.moe import MoeRunnerConfig
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         self,
         layer: EPMoE,
         x: torch.Tensor,
-        topk_output: TopKOutput,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        **kwargs,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # TODO(ch-wan): move it out of this class

@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if routed_scaling_factor is not None:
-            output *= routed_scaling_factor
+        if moe_runner_config.routed_scaling_factor is not None:
+            output *= moe_runner_config.routed_scaling_factor
         return output
python/sglang/srt/layers/quantization/w8a8_fp8.py

@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 _is_fp8_fnuz = is_fp8_fnuz()

@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -49,6 +49,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 _is_cuda = is_cuda()

@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,

@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int8_w8a8=True,
             per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )

@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer,
         x,
         topk_output: TopKOutput,
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output
python/sglang/srt/managers/schedule_batch.py

@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.moe import is_tbo_enabled
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,

@@ -84,17 +85,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_two_batch_overlap",
-    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
-    "moe_a2a_backend",
-    "deepep_mode",
-    "enable_flashinfer_cutlass_moe",
-    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
-    "ep_dispatch_algorithm",
-    "deepep_config",
     "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",

@@ -107,8 +101,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
-    "enable_triton_kernel_moe",
-    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
python/sglang/srt/managers/scheduler.py

@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
+from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,

@@ -245,6 +245,9 @@ class Scheduler(
             )
         )
 
+        # Init model config
+        self.model_config = ModelConfig.from_server_args(server_args)
+
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None

@@ -292,6 +295,9 @@ class Scheduler(
         # Init tokenizer
         self.init_tokenizer()
 
+        # Init moe config
+        self.init_moe_config()
+
         # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
         if self.server_args.reasoning_parser and self.tokenizer:
             reasoning_parser = ReasoningParser(

@@ -538,8 +544,6 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args
-
-        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:

@@ -761,6 +765,10 @@ class Scheduler(
         # The prefill requests that are in the middle of kv sending
         self.disagg_prefill_inflight_queue: List[Req] = []
 
+    def init_moe_config(self):
+        if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
+            initialize_moe_config(self.server_args)
+
     @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""

@@ -1823,11 +1831,6 @@ class Scheduler(
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
-            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=MoeA2ABackend(
-                self.server_args.moe_a2a_backend
-            ).is_deepep(),
-            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

@@ -1922,9 +1925,6 @@ class Scheduler(
         disable_cuda_graph: bool,
         spec_algorithm,
         speculative_num_draft_tokens,
-        enable_two_batch_overlap: bool,
-        enable_deepep_moe: bool,
-        deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
         disable_overlap_schedule: bool,
     ):

@@ -1972,9 +1972,6 @@ class Scheduler(
                     is_extend_in_batch,
                     *tbo_preparer.prepare_all_gather(
                         local_batch,
-                        deepep_mode,
-                        enable_deepep_moe,
-                        enable_two_batch_overlap,
                     ),
                 ],
                 dtype=torch.int64,
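The scheduler now builds the model config first and then initializes a process-wide MoE config only when the model actually is an MoE, detected via num_experts_per_tok on the HF config. A rough self-contained sketch of that gating pattern, with toy stand-ins for the model config and initialize_moe_config:

# Toy stand-ins to illustrate the gating added in Scheduler.init_moe_config;
# the real initialize_moe_config lives in sglang.srt.layers.moe and reads the
# full ServerArgs, not a single field.
from types import SimpleNamespace

_MOE_CONFIG = None


def _initialize_moe_config(server_args) -> None:
    global _MOE_CONFIG
    _MOE_CONFIG = {"moe_a2a_backend": getattr(server_args, "moe_a2a_backend", "none")}


def init_moe_config(model_hf_config, server_args) -> None:
    # Only MoE checkpoints define num_experts_per_tok, so dense models skip this.
    if hasattr(model_hf_config, "num_experts_per_tok"):
        _initialize_moe_config(server_args)


dense_cfg = SimpleNamespace(hidden_size=4096)
moe_cfg = SimpleNamespace(hidden_size=4096, num_experts_per_tok=8)
init_moe_config(dense_cfg, SimpleNamespace())
print(_MOE_CONFIG)  # None: a dense model leaves the MoE config untouched
init_moe_config(moe_cfg, SimpleNamespace(moe_a2a_backend="deepep"))
print(_MOE_CONFIG)  # {'moe_a2a_backend': 'deepep'}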
python/sglang/srt/model_executor/model_runner.py

@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,

@@ -219,8 +218,6 @@ class ModelRunner:
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
-                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
-                "deepep_mode": DeepEPMode(server_args.deepep_mode),
             }
         )
python/sglang/srt/models/dbrx.py

@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
         self.params_dtype = params_dtype
 
         self.router = DbrxRouter(config, self.params_dtype)
+        self.topk = TopK(
+            self.top_k,
+            renormalize=True,
+        )
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)
         self.ws = nn.Parameter(
             torch.empty(
                 self.num_total_experts,

@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.ws,
             self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )
         if self.tp_size > 1:

@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
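On the model side the pattern is now: compute router logits, run them through a TopK module, and hand the resulting topk_output plus a per-layer MoeRunnerConfig to fused_moe. A condensed sketch of that flow with stand-in modules (shapes and class internals are illustrative, not the real DBRX implementation):

# Condensed, self-contained illustration of the new model-side call pattern.
# _TopK and _fused_moe are toy stand-ins for sglang's TopK module and fused_moe.
from dataclasses import dataclass

import torch


@dataclass
class _MoeRunnerConfig:
    inplace: bool = True


class _TopK(torch.nn.Module):
    def __init__(self, top_k: int, renormalize: bool = True):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states, router_logits):
        weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), self.top_k, dim=-1)
        if self.renormalize:
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, ids, router_logits


def _fused_moe(hidden_states, w13, w2, topk_output, moe_runner_config):
    # The real fused_moe dispatches to fused kernels; this stand-in only shows the call shape.
    assert isinstance(moe_runner_config, _MoeRunnerConfig)
    return hidden_states


hidden = torch.randn(4, 16)
logits = torch.randn(4, 8)            # 8 experts
topk = _TopK(top_k=2)
cfg = _MoeRunnerConfig(inplace=True)  # built once in __init__, reused every forward
out = _fused_moe(hidden, None, None, topk(hidden, logits), cfg)
print(out.shape)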
python/sglang/srt/models/deepseek.py

@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention

@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
             w1=self.w1,
             w2=self.w2,
             topk_output=topk_output,
-            inplace=True,
+            moe_runner_config=MoeRunnerConfig(inplace=True),
         )
         if self.config.n_shared_experts is not None: