change / sglang · Commits · 15ad6c90

Commit 15ad6c90 (unverified), authored Jul 19, 2025 by Cheng Wan, committed via GitHub on Jul 19, 2025

[1/N] MoE Refactor: refactor `select_experts` (#7966)

Parent: cfab0ff6
Showing 20 changed files with 395 additions and 762 deletions (+395, −762)
python/sglang/srt/custom_op.py                                                       +5   −2
python/sglang/srt/layers/linear.py                                                   +1   −1
python/sglang/srt/layers/moe/ep_moe/layer.py                                         +13  −74
python/sglang/srt/layers/moe/fused_moe_native.py                                     +7   −47
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py                           +7   −38
python/sglang/srt/layers/moe/fused_moe_triton/layer.py                               +6   −29
python/sglang/srt/layers/moe/topk.py                                                 +171 −5
python/sglang/srt/layers/quantization/__init__.py                                    +9   −23
python/sglang/srt/layers/quantization/awq.py                                         +8   −31
python/sglang/srt/layers/quantization/base_config.py                                 +14  −7
python/sglang/srt/layers/quantization/blockwise_int8.py                              +7   −28
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py   +21  −71
python/sglang/srt/layers/quantization/fp8.py                                         +12  −40
python/sglang/srt/layers/quantization/gptq.py                                        +8   −27
python/sglang/srt/layers/quantization/modelopt_quant.py                              +11  −52
python/sglang/srt/layers/quantization/moe_wna16.py                                   +8   −26
python/sglang/srt/layers/quantization/unquant.py                                     +55  −152
python/sglang/srt/layers/quantization/w8a8_fp8.py                                    +9   −28
python/sglang/srt/layers/quantization/w8a8_int8.py                                   +14  −75
python/sglang/srt/models/deepseek.py                                                 +9   −6
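Taken together, the per-file stats tell the story of the refactor: `select_experts` stops being re-invoked inside every MoE backend and quantization method (the −762), and a reusable `TopK` module plus a `TopKOutput` result type take its place (the +171 in topk.py). Below is a toy sketch of the new routing contract; `TopKOutput` mirrors the NamedTuple this commit introduces, while `toy_topk` is a simplified stand-in for `select_experts` (softmax scoring only):

```python
import torch
from typing import NamedTuple

class TopKOutput(NamedTuple):
    # Mirrors the NamedTuple added in python/sglang/srt/layers/moe/topk.py.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor

def toy_topk(router_logits: torch.Tensor, top_k: int, renormalize: bool = True) -> TopKOutput:
    # Stand-in for select_experts(): softmax scoring, top-k, optional renormalization.
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return TopKOutput(topk_weights, topk_ids, router_logits)

logits = torch.randn(4, 8)             # 4 tokens, 8 experts
out = toy_topk(logits, top_k=2)
topk_weights, topk_ids, _ = out        # the unpacking idiom used throughout this diff
print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```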
python/sglang/srt/custom_op.py

@@ -29,15 +29,18 @@ class CustomOp(nn.Module):
         self._original_forward_method = self._forward_method

         # NOTE: Temporarily workaround MoE
-        # The performance of torch.compile on this layer is not always good when bs > 1,
-        # so we decide to only use torch.compile when bs=1
         if "FusedMoE" in self.__class__.__name__:
             if num_tokens == 1:
                 from sglang.srt.layers.moe.fused_moe_native import (
                     fused_moe_forward_native,
                 )

+                # The performance of torch.compile on this layer is not always good when bs > 1,
+                # so we decide to only use torch.compile when bs=1
                 self._forward_method = fused_moe_forward_native
+        elif "TopK" in self.__class__.__name__:
+            if num_tokens == 1:
+                self._forward_method = self.forward_native
         else:
             self._forward_method = self.forward_native
         self.is_torch_compile = True
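For orientation, this hunk edits CustomOp's torch.compile entry hook, which works by swapping the bound method behind `forward()`; the new `elif` gives the `TopK` module the same bs=1 compile treatment as `FusedMoE`. A minimal runnable sketch of that dispatch pattern (simplified; the real class also saves and restores the original method):

```python
import torch
import torch.nn as nn

class DispatchSketch(nn.Module):
    # Minimal stand-in for CustomOp: forward() delegates to a swappable
    # _forward_method, which the compile hook can redirect per class and batch size.
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_native

    def forward(self, x):
        return self._forward_method(x)

    def forward_native(self, x):
        return x * 2

m = DispatchSketch()
assert m(torch.tensor(3)).item() == 6
```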
python/sglang/srt/layers/linear.py

@@ -756,7 +756,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         bias: bool = True,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional["QuantizationConfig"] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
python/sglang/srt/layers/moe/ep_moe/layer.py

 import logging
-from typing import Callable, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import einops
 import torch
 from torch.nn import Module

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,

@@ -28,7 +24,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,

@@ -162,16 +158,9 @@ class EPMoE(torch.nn.Module):
         intermediate_size: int,
         layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
-        correction_bias: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
         use_per_token_if_dynamic: bool = True,

@@ -189,24 +178,12 @@ class EPMoE(torch.nn.Module):
         self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
-        assert (
-            num_fused_shared_experts == 0
-        ), "num_fused_shared_experts is not supported in EP"
-        self.num_fused_shared_experts = num_fused_shared_experts
         self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1

         self.top_k = top_k
         self.intermediate_size = intermediate_size
-        self.renormalize = renormalize
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.custom_routing_function = custom_routing_function
         self.activation = activation
         self.routed_scaling_factor = routed_scaling_factor
         self.use_per_token_if_dynamic = use_per_token_if_dynamic

@@ -311,33 +288,24 @@ class EPMoE(torch.nn.Module):
         )
         return local_num_experts, expert_map

-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-            return self.forward_deepgemm(hidden_states, router_logits)
+            return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return self.forward_normal(hidden_states, router_logits)
+            return self.forward_normal(hidden_states, topk_output)

     def forward_deepgemm(
         self,
         hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
+        topk_output: TopKOutput,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
-        topk_weights, topk_ids = select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            correction_bias=self.correction_bias,
-            custom_routing_function=self.custom_routing_function,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
+        topk_weights, topk_ids, _ = topk_output
         if not self.use_block_quant:
             # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm

@@ -469,8 +437,10 @@ class EPMoE(torch.nn.Module):
         )
         return output

-    def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.quant_method is not None
+        topk_weights, topk_ids, _ = topk_output
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device

@@ -481,23 +451,6 @@ class EPMoE(torch.nn.Module):
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )

-        topk_weights, topk_ids = select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            correction_bias=self.correction_bias,
-            custom_routing_function=self.custom_routing_function,
-            routed_scaling_factor=self.routed_scaling_factor,
-            expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                layer_id=self.layer_id,
-            ),
-        )
-
         if self.use_w4afp8:
             local_topk_ids = topk_ids
             if self.expert_map is not None:

@@ -916,16 +869,9 @@ class DeepEPMoE(EPMoE):
         intermediate_size: int,
         layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
-        correction_bias: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
         deepep_mode: DeepEPMode = DeepEPMode.auto,

@@ -937,16 +883,9 @@ class DeepEPMoE(EPMoE):
             intermediate_size=intermediate_size,
             layer_id=layer_id,
             params_dtype=params_dtype,
-            renormalize=renormalize,
-            use_grouped_topk=use_grouped_topk,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            topk_group=topk_group,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
-            correction_bias=correction_bias,
-            custom_routing_function=custom_routing_function,
             activation=activation,
             routed_scaling_factor=routed_scaling_factor,
         )
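With the routing state gone from EPMoE, the caller computes top-k first and passes the result in. A hedged sketch of the new call shape; `topk` and `experts` stand for `TopK` and `EPMoE` instances, and the concrete model-side wiring is outside this excerpt:

```python
import torch

def moe_block(hidden_states: torch.Tensor, router_logits: torch.Tensor, topk, experts):
    # topk(...) returns TopKOutput(topk_weights, topk_ids, router_logits);
    # EPMoE.forward now consumes it instead of raw router logits.
    topk_output = topk(hidden_states, router_logits)
    return experts(hidden_states, topk_output)
```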
python/sglang/srt/layers/moe/fused_moe_native.py

@@ -9,21 +9,14 @@ import torch
 from torch.nn import functional as F

 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKOutput


 def fused_moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    use_grouped_topk: bool,
-    top_k: int,
-    router_logits: torch.Tensor,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    num_fused_shared_experts: int = 0,
-    custom_routing_function: Optional[Callable] = None,
-    correction_bias: Optional[torch.Tensor] = None,
+    topk_output: TopKOutput,
     *,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,

@@ -34,20 +27,7 @@ def fused_moe_forward_native(
     if apply_router_weight_on_input:
         raise NotImplementedError()

-    topk_weights, topk_ids = select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        num_fused_shared_experts=num_fused_shared_experts,
-        custom_routing_function=custom_routing_function,
-        correction_bias=correction_bias,
-        routed_scaling_factor=routed_scaling_factor,
-        torch_native=True,
-    )
+    topk_weights, topk_ids, _ = topk_output

     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)

@@ -67,15 +47,8 @@ def fused_moe_forward_native(

 def moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    use_grouped_topk: bool,
-    top_k: int,
-    router_logits: torch.Tensor,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    num_fused_shared_experts: int = 0,
-    custom_routing_function: Optional[Callable] = None,
-    correction_bias: Optional[torch.Tensor] = None,
+    topk_output: TopKOutput,
     *,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,

@@ -86,20 +59,7 @@ def moe_forward_native(
     if apply_router_weight_on_input:
         raise NotImplementedError()

-    topk_weights, topk_ids = select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        num_fused_shared_experts=num_fused_shared_experts,
-        custom_routing_function=custom_routing_function,
-        correction_bias=correction_bias,
-        torch_native=True,
-        routed_scaling_factor=routed_scaling_factor,
-    )
+    topk_weights, topk_ids, _ = topk_output

     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
     len_experts = layer.num_experts
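The surviving body of the native path is easy to follow: index the stacked expert weights with the selected ids, then split the fused w13 projection into its gate (w1) and up (w3) halves. A self-contained sketch of that step, assuming w13_weight is laid out as [num_experts, 2 * intermediate_size, hidden_size] (the layout the chunk on dim=2 implies):

```python
import torch

num_experts, intermediate_size, hidden_size = 8, 16, 24
num_tokens, top_k = 4, 2

w13_weight = torch.randn(num_experts, 2 * intermediate_size, hidden_size)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))

# Gather each token's chosen expert weights, then split gate/up projections.
w13_weights = w13_weight[topk_ids]                           # [4, 2, 32, 24]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)  # two [4, 2, 16, 24] halves
print(w1_weights.shape, w3_weights.shape)
```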
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -6,13 +6,13 @@ import functools
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     scaled_fp8_quant,

@@ -1328,8 +1328,7 @@ def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,

@@ -1348,7 +1347,7 @@ def fused_experts(
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ):
+    topk_weights, topk_ids, _ = topk_output
     if inplace:
         assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(

@@ -1732,17 +1731,10 @@ def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
-    use_grouped_topk: bool = False,
-    num_expert_group: Optional[int] = None,
-    num_fused_shared_experts: int = 0,
-    topk_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,

@@ -1766,16 +1758,9 @@ def fused_moe(
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - topk_output (TopKOutput): The top-k output of the experts.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
-    - num_expert_group: Optional[int]: additional parameter for grouped_topk
-    - topk_group: Optional[int]: additional parameter for grouped_topk
-    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
-        note: Deepseek V2/V3/R1 series models use grouped_topk
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner

@@ -1799,28 +1784,12 @@ def fused_moe(
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
-    # Check constraints.
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-
-    topk_weights, topk_ids = select_experts(
-        hidden_states=hidden_states,
-        router_logits=gating_output,
-        use_grouped_topk=use_grouped_topk,
-        top_k=topk,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        num_fused_shared_experts=num_fused_shared_experts,
-        custom_routing_function=custom_routing_function,
-        routed_scaling_factor=routed_scaling_factor,
-    )

     return fused_experts(
         hidden_states,
         w1,
         w2,
-        topk_weights,
-        topk_ids,
+        topk_output,
         inplace=inplace,
         activation=activation,
         apply_router_weight_on_input=apply_router_weight_on_input,
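The Triton entry points now take the routing result as a single argument and unpack it on entry. A minimal runnable sketch of that contract, with the kernel launch stubbed out (the stub and its tensors are illustrative, not the real `fused_experts`):

```python
import torch

def fused_experts_stub(hidden_states, w1, w2, topk_output, inplace=False):
    # Mirrors the new fused_experts() entry: unpack the TopKOutput first,
    # then dispatch to the Triton kernels (elided here).
    topk_weights, topk_ids, _ = topk_output
    assert topk_weights.shape == topk_ids.shape  # [num_tokens, top_k]
    return hidden_states

# Any (weights, ids, logits) triple satisfies tuple unpacking; the real code
# passes the TopKOutput NamedTuple from sglang.srt.layers.moe.topk.
x = torch.randn(4, 32)
triple = (torch.rand(4, 2), torch.randint(0, 8, (4, 2)), torch.randn(4, 8))
out = fused_experts_stub(x, None, None, triple)
```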
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -2,7 +2,7 @@
 import logging
 from enum import Enum
-from typing import Callable, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,

@@ -59,22 +60,15 @@ class FusedMoE(torch.nn.Module):
     def __init__(
         self,
         num_experts: int,
-        top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        top_k: Optional[int] = None,
         layer_id: Optional[int] = None,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         use_presharded_weights: bool = False,

@@ -89,6 +83,7 @@ class FusedMoE(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

+        self.top_k = top_k
         self.hidden_size = hidden_size
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()

@@ -126,19 +121,9 @@ class FusedMoE(torch.nn.Module):
             self.ep_rank = 0
             self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
-        self.top_k = top_k
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
-        self.renormalize = renormalize
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.topk_group = topk_group
-        self.custom_routing_function = custom_routing_function
-        self.correction_bias = correction_bias
         self.activation = activation
         self.apply_router_weight_on_input = apply_router_weight_on_input
         self.use_presharded_weights = use_presharded_weights

@@ -562,22 +547,14 @@ class FusedMoE(torch.nn.Module):
         )
         return

-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.quant_method is not None

         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            renormalize=self.renormalize,
-            use_grouped_topk=self.use_grouped_topk,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
+            topk_output=topk_output,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
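With routing hyperparameters removed from FusedMoE, a model layer is expected to own a TopK module next to its expert layer. A hedged sketch of that wiring: the real model-side change lands in python/sglang/srt/models/deepseek.py (+9 −6, not shown in this excerpt), and `TopK`/`FusedMoE` here are injected rather than imported to keep the sketch self-contained:

```python
import torch
import torch.nn as nn

class MoEBlockSketch(nn.Module):
    # Illustrative wiring only; constructor arguments are trimmed and the
    # TopK/FusedMoE classes stand for the real sglang modules.
    def __init__(self, num_experts, top_k, hidden_size, intermediate_size, TopK, FusedMoE):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.topk = TopK(top_k=top_k, renormalize=True)
        self.experts = FusedMoE(num_experts=num_experts, top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        return self.experts(hidden_states, topk_output)
```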
python/sglang/srt/layers/moe/topk.py

@@ -12,12 +12,15 @@
 # limitations under the License.
 # ==============================================================================

 from __future__ import annotations

 import math
-from typing import Callable, Optional
+from typing import TYPE_CHECKING, Callable, NamedTuple, Optional

 import torch
 import torch.nn.functional as F

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.eplb import expert_location_dispatch
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location_dispatch import (

@@ -52,6 +55,168 @@ if _use_aiter:
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")

 if _is_npu:
     import torch_npu


+class TopKOutput(NamedTuple):
+    topk_weights: torch.Tensor
+    topk_ids: torch.Tensor
+    router_logits: torch.Tensor
+
+
+class TopK(CustomOp):
+
+    # TODO(ch-wan): support triton_kernels
+
+    def __init__(
+        self,
+        top_k: int,
+        *,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        renormalize: bool = True,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        correction_bias: Optional[torch.Tensor] = None,
+        routed_scaling_factor: Optional[float] = None,
+    ):
+        # NOTE: scoring_func is not used for now, but we keep it for future use
+        # see https://github.com/sgl-project/sglang/pull/4505 for more details
+        super().__init__()
+        if use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.top_k = top_k
+        self.use_grouped_topk = use_grouped_topk
+        self.renormalize = renormalize
+        self.topk_group = topk_group
+        self.num_expert_group = num_expert_group
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.custom_routing_function = custom_routing_function
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        torch_native = True
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            torch_native=torch_native,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        torch_native = False
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            torch_native=torch_native,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_cpu(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_npu(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        global_num_experts = router_logits.shape[-1]
+
+        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+        if global_num_experts == 256:
+            return torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=self.top_k,
+                bias=self.correction_bias,
+                k_group=self.topk_group,
+                group_count=self.num_expert_group,
+                group_select_mode=1,
+                renorm=0,
+                norm_type=1,
+                routed_scaling_factor=1,
+                eps=float(1e-20),
+            )
+        else:
+            torch_native = True
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
+
+
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,

@@ -436,8 +601,9 @@ def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
+    *,
+    use_grouped_topk: bool = False,
+    renormalize: bool = False,
     topk_group: Optional[int] = None,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,

@@ -447,7 +613,7 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-):
+) -> TopKOutput:
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
             router_logits=router_logits,

@@ -522,4 +688,4 @@ def select_experts(
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

-    return topk_weights, topk_ids
+    return TopKOutput(topk_weights, topk_ids, router_logits)
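Because TopKOutput is a NamedTuple, both access idioms seen across this diff work on the same object. A small self-contained demo (redeclaring the tuple so it runs without sglang installed; shapes are illustrative):

```python
import torch
from typing import NamedTuple

class TopKOutput(NamedTuple):
    # Same fields and order as the class added above.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor

logits = torch.randn(4, 8)
out = TopKOutput(torch.rand(4, 2), torch.randint(0, 8, (4, 2)), logits)

topk_weights, topk_ids, _ = out     # positional: most backends ignore the logits
assert out.router_logits is logits  # named: the Marlin paths still read them
```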
python/sglang/srt/layers/quantization/__init__.py

 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+from __future__ import annotations

 import builtins
 import inspect
-from typing import Callable, Dict, Optional, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union

 import torch

@@ -65,6 +67,9 @@ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,

@@ -186,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,

@@ -208,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
             "self": self,
             "layer": layer,
             "x": x,
-            "router_logits": router_logits,
-            "top_k": top_k,
-            "renormalize": renormalize,
-            "use_grouped_topk": use_grouped_topk,
-            "topk_group": topk_group,
-            "num_expert_group": num_expert_group,
-            "custom_routing_function": custom_routing_function,
+            "topk_output": topk_output,
         }
-        if correction_bias is not None:
-            if not has_correction_bias:
-                raise ValueError(
-                    "Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`"
-                )
-            kwargs["e_score_correction_bias"] = correction_bias
         return original_apply(**kwargs)

     setattr(class_obj, "apply", new_apply)
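monkey_patch_moe_apply rewrites third-party apply methods to the new keyword set. A minimal sketch of the wrapping pattern (simplified; the real function also inspects the wrapped signature, e.g. for correction-bias support):

```python
def monkey_patch_apply(class_obj):
    original_apply = class_obj.apply  # plain function, so "self" can be a keyword

    def new_apply(self, layer, x, topk_output, *, activation="silu", **_extra):
        # Rebuild the call with only the keywords the new interface defines.
        kwargs = {"self": self, "layer": layer, "x": x, "topk_output": topk_output}
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)
```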
python/sglang/srt/layers/quantization/awq.py

@@ -3,7 +3,7 @@ from __future__ import annotations

 import logging
 import warnings
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch

@@ -33,6 +33,9 @@ from sglang.srt.layers.quantization.scalar_type import scalar_types
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import replace_parameter

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 try:
     from vllm import _custom_ops as ops

@@ -737,45 +740,19 @@ class AWQMoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:
-        # Delay the import to avoid circular dependency
-        from sglang.srt.layers.moe.topk import select_experts
-
         assert activation == "silu", "Only SiLU activation is supported."
-        assert (
-            scoring_func == "softmax"
-        ), "Only softmax score func is supported for now."

         # The input must currently be float16
         orig_dtype = x.dtype
         x = x.half()

-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            top_k=top_k,
-            use_grouped_topk=use_grouped_topk,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
+        topk_weights, topk_ids, router_logits = topk_output

         return fused_marlin_moe(
             x,
python/sglang/srt/layers/quantization/base_config.py

 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
+from __future__ import annotations

 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

 import torch
 from torch import nn

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+

 class QuantizeMethodBase(ABC):
     """Base class for different quantized methods."""

@@ -88,19 +92,22 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        raise NotImplementedError()
+        raise NotImplementedError

     @abstractmethod
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        raise NotImplementedError()
+        raise NotImplementedError


 class QuantizationConfig(ABC):
python/sglang/srt/layers/quantization/blockwise_int8.py

@@ -3,7 +3,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from torch.nn import Module

@@ -21,6 +21,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = logging.getLogger(__name__)

@@ -344,15 +347,8 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,

@@ -360,30 +356,13 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         # Expert fusion with INT8 quantization
         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
 import enum
 import logging
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy

+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (

@@ -20,6 +22,12 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsConfig,
+    )
+
 _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()

@@ -51,7 +59,7 @@ __all__ = [
 ]


-class CompressedTensorsMoEMethod:
+class CompressedTensorsMoEMethod(FusedMoEMethodBase):

     def __new__(cls, *args, **kwargs):
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)

@@ -59,7 +67,7 @@ class CompressedTensorsMoEMethod:
     @staticmethod
     def get_moe_method(
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        quant_config: CompressedTensorsConfig,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.

@@ -82,9 +90,7 @@ class CompressedTensorsMoEMethod:
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(

@@ -270,47 +276,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
-        apply_router_weight_on_input: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,

@@ -327,9 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.

@@ -628,43 +606,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.topk import select_experts
-
         assert activation == "silu", "Only SiLU activation is supported."
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for " "fused Marlin MoE method."
-            )
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
+        topk_weights, topk_ids, router_logits = topk_output

         return torch.ops.vllm.fused_marlin_moe(
             x,
python/sglang/srt/layers/quantization/fp8.py

@@ -3,7 +3,7 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import torch
 import torch.nn.functional as F

@@ -78,6 +78,7 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config

 _is_hip = is_hip()

@@ -971,15 +972,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,

@@ -987,26 +981,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

+            topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
                 apply_router_weight_on_input, topk_weights, x
             )

@@ -1032,8 +1011,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         ret = self.maybe_apply_hip_fused_experts(
             layer,
             x,
-            topk_weights,
-            topk_ids,
+            topk_output,
             activation,
             no_combine,
         )

@@ -1048,6 +1026,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         ):
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

+            topk_weights, topk_ids, _ = topk_output
             return cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),

@@ -1076,8 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace and not no_combine,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,

@@ -1101,11 +1079,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
+        topk_output: TopKOutput,
         activation: str = "silu",
         no_combine: bool = False,
     ) -> Optional[torch.Tensor]:
+        topk_weights, topk_ids, _ = topk_output
         if _use_hip_int4:
             # TODO: add triton kernel and add check _use_aiter
             assert not no_combine, f"{no_combine=} is not supported."

@@ -1397,14 +1375,8 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
     ) -> torch.Tensor:
         raise NotImplementedError
python/sglang/srt/layers/quantization/gptq.py

@@ -3,7 +3,7 @@ from __future__ import annotations

 import logging
 from dataclasses import dataclass
 from fractions import Fraction
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 import torch

@@ -43,6 +43,9 @@ from sglang.srt.layers.quantization.utils import (
     unpack_cols,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 try:
     from vllm import _custom_ops as ops
 except ImportError:

@@ -1057,42 +1060,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         **kwargs,
     ) -> torch.Tensor:
-        # Delay the import to avoid circular dependency
-        from sglang.srt.layers.moe.topk import select_experts
-
         assert activation == "silu", "Only SiLU activation is supported."
-        assert (
-            scoring_func == "softmax"
-        ), "Only softmax score func is supported for now."

         # The input must currently be float16
         orig_dtype = x.dtype
         x = x.half()

-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            correction_bias=e_score_correction_bias,
-        )
+        topk_weights, topk_ids, router_logits = topk_output

         return fused_marlin_moe(
             x,
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -2,7 +2,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter

@@ -31,6 +31,9 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_cuda, next_power_of_2

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

@@ -402,15 +405,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,

@@ -418,29 +414,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,

@@ -961,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,

@@ -982,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."

-        from sglang.srt.layers.moe.topk import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
-
         if self.enable_flashinfer_moe:
             assert (

@@ -1004,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
+            topk_weights, topk_ids, _ = topk_output
             output = flashinfer_cutlass_fused_moe(
                 x,
                 topk_ids.to(torch.int),

@@ -1029,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

+        topk_weights, topk_ids, _ = topk_output
         return cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
python/sglang/srt/layers/quantization/moe_wna16.py

@@ -2,8 +2,9 @@
 from __future__ import annotations

 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import numpy as np
 import torch

 from sglang.srt.distributed import get_tensor_model_parallel_rank

@@ -20,6 +21,9 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+

 def get_weight_perm(num_bits: int):
     perm_list: List[int] = []

@@ -348,15 +352,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,

@@ -365,22 +362,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts

         assert activation == "silu", "Only SiLU activation is supported."

-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            top_k=top_k,
-            use_grouped_topk=use_grouped_topk,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp

@@ -389,8 +372,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
             x,
             layer.w13_qweight,
             layer.w2_qweight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_int4_w4a16=weight_bits == 4,
python/sglang/srt/layers/quantization/unquant.py
View file @
15ad6c90
from
__future__
import
annotations
import
importlib
from
typing
import
Callable
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
import
torch
import
torch.nn.functional
as
F
...
...
@@ -21,6 +23,9 @@ from sglang.srt.utils import (
use_intel_amx_backend
,
)
if
TYPE_CHECKING
:
from
sglang.srt.layers.moe.topk
import
TopKOutput
has_triton_kernels
=
importlib
.
util
.
find_spec
(
"triton_kernels"
)
is
not
None
...
...
@@ -125,25 +130,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
super
().
__init__
()
self
.
use_triton_kernels
=
use_triton_kernels
from
sglang.srt.layers.moe.fused_moe_native
import
moe_forward_native
if
torch
.
cuda
.
is_available
():
from
sglang.srt.layers.moe.fused_moe_triton.fused_moe
import
fused_experts
if
has_triton_kernels
:
from
sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe
import
(
triton_kernel_moe_forward
,
)
else
:
triton_kernel_moe_forward
=
None
else
:
fused_experts
=
None
# type: ignore
triton_kernel_moe_forward
=
None
self
.
moe_forward_native
=
moe_forward_native
self
.
fused_experts
=
fused_experts
self
.
triton_kernel_moe_forward
=
triton_kernel_moe_forward
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -201,34 +187,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
         *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
             layer=layer,
-            router_logits=router_logits,
-            top_k=top_k,
-            renormalize=renormalize,
-            use_grouped_topk=use_grouped_topk,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
+            topk_output=topk_output,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
...
@@ -240,15 +210,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
         *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
...
@@ -257,33 +220,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:

         if self.use_triton_kernels:
-            return self.triton_kernel_moe_forward(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
+            # TODO(ch-wan): re-enable the Triton kernel
+            raise NotImplementedError("The Triton kernel is temporarily disabled.")
+            # return triton_kernel_moe_forward(
+            #     hidden_states=x,
+            #     w1=layer.w13_weight,
+            #     w2=layer.w2_weight,
+            #     gating_output=router_logits,
+            #     topk=top_k,
+            #     renormalize=renormalize,
+            # )
         else:
-            from sglang.srt.layers.moe.topk import select_experts
-
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                num_fused_shared_experts=num_fused_shared_experts,
-                custom_routing_function=custom_routing_function,
-                correction_bias=correction_bias,
-                routed_scaling_factor=routed_scaling_factor,
-            )
-
             if _use_aiter:
                 assert not no_combine, "unsupported"
+                topk_weights, topk_ids, _ = topk_output
                 if apply_router_weight_on_input:
                     assert (
                         topk_weights.dim() == 2
...
@@ -296,7 +246,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights = torch.ones_like(
                     topk_weights, dtype=torch.float32
                 )  # topk_weights must be FP32 (float32)
-
                 return fused_moe(
                     x,
                     layer.w13_weight,
...
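The context lines above preserve the bookkeeping for `apply_router_weight_on_input`: once the router weight has been folded into the activations, the combine step must not scale the expert outputs a second time, so the kernel receives FP32 ones instead. A toy check of that bookkeeping, deliberately using linear experts so the two orderings agree exactly (with a nonlinear expert MLP the two placements are genuinely different model definitions, not a free rewrite):

import torch

torch.manual_seed(0)
x = torch.randn(1, 8)  # one token, hidden size 8
experts = [torch.nn.Linear(8, 8, bias=False) for _ in range(2)]
w = torch.tensor([0.7, 0.3])  # router weights for the two selected experts

# Weight on output: scale each expert's output, combine with the real weights.
out_a = sum(w[i] * experts[i](x) for i in range(2))

# Weight on input: pre-scale the input, then combine with weights of 1.0.
ones = torch.ones_like(w)
out_b = sum(ones[i] * experts[i](w[i] * x) for i in range(2))

assert torch.allclose(out_a, out_b, atol=1e-6)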
@@ -310,12 +259,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 ),
             )
         else:
-            return self.fused_experts(
+            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+                fused_experts,
+            )
+
+            return fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
+                topk_output=topk_output,
                 inplace=inplace and not no_combine,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input,
...
@@ -327,15 +279,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
         *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
...
@@ -344,30 +289,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."

-        if use_intel_amx_backend(layer):
-            from sglang.srt.layers.moe.topk import (
-                apply_topk_weights_cpu,
-                select_experts,
-            )
-
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                num_fused_shared_experts=num_fused_shared_experts,
-                custom_routing_function=custom_routing_function,
-                correction_bias=correction_bias,
-                routed_scaling_factor=routed_scaling_factor,
-            )
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

+            topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
                 apply_router_weight_on_input, topk_weights, x
             )

             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
...
@@ -385,61 +313,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 True,  # is_vnni
             )
         else:
-            return self.moe_forward_native(
+            from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
+
+            return moe_forward_native(
                 layer,
                 x,
-                use_grouped_topk,
-                top_k,
-                router_logits,
-                renormalize,
-                topk_group,
-                num_expert_group,
-                num_fused_shared_experts,
-                custom_routing_function,
-                correction_bias,
-                activation,
-                apply_router_weight_on_input,
-                inplace,
-                no_combine,
-                routed_scaling_factor,
+                topk_output,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                inplace=inplace,
+                no_combine=no_combine,
+                routed_scaling_factor=routed_scaling_factor,
             )

     def forward_npu(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
         *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        return self.moe_forward_native(
+        from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
+
+        return moe_forward_native(
             layer,
             x,
-            use_grouped_topk,
-            top_k,
-            router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
-            num_fused_shared_experts,
-            custom_routing_function,
-            correction_bias,
-            activation,
-            apply_router_weight_on_input,
-            inplace,
-            no_combine,
-            routed_scaling_factor,
+            topk_output,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
...
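Every refactored signature above keeps a bare `*` before the remaining options, and the `moe_forward_native` call sites switch from a long positional tail to explicit keywords. The `*` makes everything after it keyword-only, so a caller can no longer silently bind, say, an activation string into the wrong slot when a parameter is added or removed. A small sketch of the failure mode this prevents:

def apply_sketch(x, topk_output, *, activation="silu", inplace=True):
    # `activation` and `inplace` are keyword-only: positional calls fail loudly.
    return (x, topk_output, activation, inplace)

apply_sketch("x", "topk")                 # fine: options take their defaults
apply_sketch("x", "topk", inplace=False)  # fine: explicit keyword
try:
    apply_sketch("x", "topk", "gelu")     # TypeError, caught at the call site
except TypeError as e:
    print(e)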
@@ -508,13 +417,7 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
     ) -> torch.Tensor:
         raise NotImplementedError
python/sglang/srt/layers/quantization/w8a8_fp8.py View file @ 15ad6c90

 from __future__ import annotations

-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter
...
@@ -25,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
 )
 from sglang.srt.utils import set_weight_attrs

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 _is_fp8_fnuz = is_fp8_fnuz()
...
@@ -266,45 +269,23 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
         *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             apply_router_weight_on_input=apply_router_weight_on_input,
             activation=activation,
             use_fp8_w8a8=True,
             per_channel_quant=True,
...
python/sglang/srt/layers/quantization/w8a8_int8.py View file @ 15ad6c90
...
@@ -3,7 +3,7 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

 import torch
 from torch.nn.parameter import Parameter
...
@@ -37,6 +37,9 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
...
@@ -239,7 +242,7 @@ class W8A8Int8Config(QuantizationConfig):
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if _is_npu:
...
@@ -469,15 +472,8 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
         *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
...
@@ -485,26 +481,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

+            topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
                 apply_router_weight_on_input, topk_weights, x
             )
...
@@ -529,8 +510,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
...
@@ -907,7 +887,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: List[int],
+        intermediate_size: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
...
@@ -984,52 +964,11 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         self,
         layer,
         x,
-        router_logits,
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        num_fused_shared_experts,
-        custom_routing_function,
-        correction_bias,
-        activation,
-        apply_router_weight_on_input,
-        routed_scaling_factor,
+        topk_output: TopKOutput,
         **kwargs,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.topk import select_experts
-
-        global_num_experts = router_logits.shape[-1]
-
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,
-                bias=correction_bias,
-                k_group=topk_group,
-                group_count=num_expert_group,
-                group_select_mode=1,
-                renorm=0,
-                norm_type=1,
-                routed_scaling_factor=1,
-                eps=float(1e-20),
-            )
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                num_fused_shared_experts=num_fused_shared_experts,
-                custom_routing_function=custom_routing_function,
-                correction_bias=correction_bias,
-                torch_native=True,
-                routed_scaling_factor=routed_scaling_factor,
-            )
+        topk_weights, topk_ids, _ = topk_output

         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
         return npu_fused_experts(
...
@@ -1040,5 +979,5 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            top_k=top_k,
+            top_k=topk_ids.shape[1],
         )
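A small but easy-to-miss detail in the hunk above: `top_k` is no longer a parameter, so the NPU path derives it from the routing result itself. `topk_ids` has shape `(num_tokens, top_k)`, hence:

import torch

topk_ids = torch.randint(0, 256, (32, 8))  # 32 tokens, 8 experts per token
top_k = topk_ids.shape[1]                  # recovers the old `top_k` argument
assert top_k == 8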
python/sglang/srt/models/deepseek.py View file @ 15ad6c90
...
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
...
@@ -109,7 +110,10 @@ class DeepseekMoE(nn.Module):
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {self.n_routed_experts}."
             )

+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=config.norm_topk_prob,
+        )
         self.experts = nn.ModuleList(
             [
                 DeepseekMLP(
...
@@ -170,13 +174,12 @@ class DeepseekMoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe.fused_moe(
             hidden_states,
-            self.w1,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=self.config.norm_topk_prob,
+            w1=self.w1,
+            w2=self.w2,
+            topk_output=topk_output,
             inplace=True,
         )
...
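Taken together, the deepseek.py hunks show the shape of the whole refactor: the layer builds one `TopK` module next to its gate, and the forward pass hands the module's output to the fused kernel instead of re-stating the routing policy (top_k, renormalize, and so on) at every call site. A self-contained toy of that flow, using plain torch stand-ins for sglang's `TopK` and `fused_moe` (class and function names here are illustrative, not the real API):

import torch

class ToyTopK(torch.nn.Module):
    # Stand-in for sglang's TopK: owns the routing policy once, centrally.
    def __init__(self, top_k: int, renormalize: bool) -> None:
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states, router_logits):
        weights, ids = torch.topk(router_logits.softmax(dim=-1), self.top_k, dim=-1)
        if self.renormalize:
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, ids, router_logits  # unpacks like the diff's TopKOutput

def toy_fused_moe(x, experts, topk_output):
    # The consumer only unpacks; it no longer knows how routing was configured.
    topk_weights, topk_ids, _ = topk_output
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for w, e in zip(topk_weights[t], topk_ids[t]):
            out[t] += w * experts[int(e)](x[t])
    return out

hidden, num_experts = 16, 4
experts = torch.nn.ModuleList(torch.nn.Linear(hidden, hidden) for _ in range(num_experts))
gate = torch.nn.Linear(hidden, num_experts, bias=False)
topk = ToyTopK(top_k=2, renormalize=True)

x = torch.randn(3, hidden)
topk_output = topk(x, gate(x))                        # routing decided once, here
print(toy_fused_moe(x, experts, topk_output).shape)   # torch.Size([3, 16])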