sglang · Commits · 3fa62da7

Commit 3fa62da7 (Unverified)
Authored Sep 05, 2025 by Cheng Wan; committed by GitHub on Sep 05, 2025
Parent: dbb1235d

[7/N] MoE Refactor: the implementation of new framework (#9269)

Showing 14 changed files with 489 additions and 261 deletions (+489 -261)
python/sglang/srt/layers/quantization/gptq.py             +25 -17
python/sglang/srt/layers/quantization/modelopt_quant.py   +60 -35
python/sglang/srt/layers/quantization/moe_wna16.py        +21 -18
python/sglang/srt/layers/quantization/mxfp4.py            +64 -40
python/sglang/srt/layers/quantization/quark/quark_moe.py  +32 -27
python/sglang/srt/layers/quantization/unquant.py          +67 -43
python/sglang/srt/layers/quantization/w4afp8.py           +26 -17
python/sglang/srt/layers/quantization/w8a8_fp8.py         +35 -20
python/sglang/srt/layers/quantization/w8a8_int8.py        +71 -31
python/sglang/srt/managers/schedule_batch.py               +0 -1
python/sglang/srt/model_loader/__init__.py                 +9 -3
python/sglang/srt/model_loader/loader.py                  +18 -4
python/sglang/test/test_cutlass_moe.py                    +24 -5
test/srt/test_mla_deepseek_v3.py                          +37 -0
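Every quantization backend below gets the same mechanical treatment: apply() stops taking (x, topk_output, moe_runner_config) and returning a bare tensor, and instead consumes a StandardDispatchOutput and returns a CombineInput, with the runner config bound once through a new create_moe_runner() hook. A minimal sketch of the before/after contract, using only the signatures that appear in the diffs (the stub bodies and the class name are illustrative, not sglang code):

    # Sketch of the interface change this commit applies to every
    # FusedMoEMethodBase subclass touched below.
    import torch

    class SomeMoEMethod:
        # Old contract: the caller threads hidden states, top-k routing,
        # and the runner config through every call.
        def apply_old(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            topk_output: "TopKOutput",
            moe_runner_config: "MoeRunnerConfig",
        ) -> torch.Tensor: ...

        # New contract: the config is bound once up front...
        def create_moe_runner(
            self, layer: torch.nn.Module, moe_runner_config: "MoeRunnerConfig"
        ):
            self.moe_runner_config = moe_runner_config

        # ...and apply() consumes the dispatcher's output and hands back a
        # CombineInput, so dispatch/combine become pluggable around the kernel.
        def apply_new(
            self, layer: torch.nn.Module, dispatch_output: "StandardDispatchOutput"
        ) -> "CombineInput":
            x = dispatch_output.hidden_states
            topk_output = dispatch_output.topk_output
            ...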
python/sglang/srt/layers/quantization/gptq.py

@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )

 from sglang.srt.utils import is_cuda
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

-        intermediate_size = extra_weight_attrs.pop("intermediate_size")
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1

         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            w2_scales_size = (
-                intermediate_size
-                if self.quant_config.desc_act
-                else intermediate_size_per_partition
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
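The scale-tensor sizing above is plain integer division over quantization groups; a worked example with made-up shapes (the numbers are hypothetical, not from the commit):

    # Hypothetical shapes illustrating the scale sizing in the hunk above.
    hidden_size = 4096
    intermediate_size_per_partition = 1408  # per-TP-rank shard (example value)
    moe_tp_size = 4
    group_size = 128
    desc_act = False

    scales_size13 = hidden_size // group_size  # 4096 // 128 = 32 groups along K of w13
    if desc_act:
        w2_scales_size = intermediate_size_per_partition
    else:
        w2_scales_size = intermediate_size_per_partition * moe_tp_size
    scales_size2 = w2_scales_size // group_size  # (1408 * 4) // 128 = 44
    print(scales_size13, scales_size2)  # 32 44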
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         topk_weights, topk_ids, router_logits = topk_output

-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
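Instead of returning the kernel's tensor directly, every path now wraps it in StandardCombineInput before handing it back. A minimal stand-in for that container (the real class lives in sglang.srt.layers.moe.token_dispatcher; this sketch only assumes it carries hidden_states):

    # Illustrative stand-in for the wrapper used by `return
    # StandardCombineInput(hidden_states=output)` throughout this commit.
    from dataclasses import dataclass
    import torch

    @dataclass
    class StandardCombineInputSketch:
        hidden_states: torch.Tensor

    def finish(kernel_out: torch.Tensor) -> StandardCombineInputSketch:
        # The combine stage of the dispatcher consumes this object.
        return StandardCombineInputSketch(hidden_states=kernel_out)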
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values

             # Requantize each expert's weights using the combined scale
-            # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-            # where the first intermediate_size rows are w1, the next are w3
-            intermediate_size = layer.w13_weight.shape[1] // 2
+            # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+            # where the first intermediate_size_per_partition rows are w1, the next are w3
+            intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
             for expert_id in range(layer.w13_weight.shape[0]):
                 start = 0
                 for shard_id in range(2):  # w1 and w3
                     # Dequantize using the original scale for this shard
                     dq_weight = per_tensor_dequantize(
                         layer.w13_weight[expert_id][
-                            start : start + intermediate_size, :
+                            start : start + intermediate_size_per_partition, :
                         ],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
                     # Requantize using the combined max scale
                     (
                         layer.w13_weight[expert_id][
-                            start : start + intermediate_size, :
+                            start : start + intermediate_size_per_partition, :
                         ],
                         _,
                     ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

-                    start += intermediate_size
+                    start += intermediate_size_per_partition

             # Update the scale parameter to be per-expert instead of per-shard
             layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
-            use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-        )
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+        )
+        return self.runner.run(dispatch_output, quant_info)
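Several backends in this commit (ModelOpt FP8 above, moe_wna16, mxfp4's fallback path, unquant, w8a8_fp8) now funnel through the same two calls: build a TritonMoeQuantInfo describing the weights and scales, then hand it to the runner together with the dispatch output. A hedged sketch of that flow (class and field names are taken from the diff; the runner's internals are assumed to forward the metadata to the Triton fused-MoE kernel):

    # Sketch of the new runner flow; illustrative only.
    from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
    from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo

    def run_fp8_moe(layer, dispatch_output, moe_runner_config: MoeRunnerConfig):
        # Bound once per layer via create_moe_runner() in the real code.
        runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,
            per_channel_quant=False,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
        return runner.run(dispatch_output, quant_info)  # -> CombineInput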
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""

@@ -1278,21 +1292,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1345,13 +1370,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1372,4 +1396,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+        return StandardCombineInput(hidden_states=output)
python/sglang/srt/layers/quantization/moe_wna16.py

@@ -9,6 +9,8 @@ import torch
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp

-        return fused_experts(
-            x,
-            layer.w13_qweight,
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-            w1_scale=layer.w13_scales,
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-            w1_zp=layer.w13_qzeros if has_zp else None,
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)

     @staticmethod
     def get_weight_loader(layer, weight_loader):
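In the WNA16 path the kernel variant follows directly from the checkpoint's weight width, and zero-points are passed only for asymmetric checkpoints. A small illustration of that selection (a stand-in helper; values are hypothetical):

    def kernel_flags(weight_bits: int, has_zp: bool) -> dict:
        # Mirrors the TritonMoeQuantInfo arguments in the hunk above: exactly
        # one of the two kernel flags is set for a valid 4- or 8-bit
        # checkpoint, and zero-points ride along only when has_zp is True.
        return {
            "use_int4_w4a16": weight_bits == 4,
            "use_int8_w8a16": weight_bits == 8,
            "needs_zero_points": has_zp,
        }

    assert kernel_flags(4, False)["use_int4_w4a16"]
    assert kernel_flags(8, True)["use_int8_w8a16"]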
python/sglang/srt/layers/quantization/mxfp4.py

@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 from torch.nn.parameter import Parameter

+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -59,8 +61,10 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 _is_hip = is_hip()
@@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,
@@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        intermediate_size_per_partition_after_pad = intermediate_size
+        intermediate_size_per_partition_after_pad = intermediate_size_per_partition
         if _is_sm100_supported:
             if self.use_flashinfer:
                 intermediate_size_per_partition_after_pad = round_up(
-                    intermediate_size, 256
+                    intermediate_size_per_partition, 256
                 )
                 hidden_size = round_up(hidden_size, 256)
             else:
                 intermediate_size_per_partition_after_pad = round_up(
-                    intermediate_size, 64
+                    intermediate_size_per_partition, 64
                 )
         elif has_triton_kernels:
             # TODO: this is a hack to make
             # intermediate_size_per_partition_after_pad the same as the
             # per_rank_intermediate_size during weight loading
             intermediate_size_per_partition_after_pad = round_up(
-                intermediate_size, mxfp4_block
+                intermediate_size_per_partition, mxfp4_block
             )

-        self.intermediate_size = intermediate_size_per_partition_after_pad
+        self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad
         self.hidden_size = hidden_size
         # Fused gate_up_proj (column parallel)
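round_up pads the per-partition intermediate size to the alignment each kernel path needs: 256 for the FlashInfer path on SM100, 64 for the other SM100 paths, and one MXFP4 block for the triton-kernels path. A quick worked example (the shard size is hypothetical, and mxfp4_block is assumed to be 32):

    # Padding arithmetic used above; round_up(x, m) rounds x up to the
    # next multiple of m.
    def round_up(x: int, m: int) -> int:
        return (x + m - 1) // m * m

    intermediate_size_per_partition = 1440  # example shard size
    print(round_up(intermediate_size_per_partition, 256))  # 1536 (flashinfer, SM100)
    print(round_up(intermediate_size_per_partition, 64))   # 1472 (other SM100 paths)
    print(round_up(intermediate_size_per_partition, 32))   # 1440 (assumed mxfp4_block)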
@@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             assert (
                 layer.w13_weight.dim() == 3
                 and layer.w13_weight.shape[0] == self.num_experts
-                and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight.shape[1] == self.intermediate_size_per_partition * 2
                 and layer.w13_weight.shape[2] == self.hidden_size // 2
             )
             assert (
                 layer.w13_weight_scale.dim() == 3
                 and layer.w13_weight_scale.shape[0] == self.num_experts
-                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight_scale.shape[1]
+                == self.intermediate_size_per_partition * 2
                 and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
             )
             assert (
                 layer.w2_weight.dim() == 3
                 and layer.w2_weight.shape[0] == self.num_experts
                 and layer.w2_weight.shape[1] == self.hidden_size
-                and layer.w2_weight.shape[2] == self.intermediate_size // 2
+                and layer.w2_weight.shape[2] == self.intermediate_size_per_partition // 2
             )
             assert (
                 layer.w2_weight_scale.dim() == 3
                 and layer.w2_weight_scale.shape[1] == self.hidden_size
                 and layer.w2_weight_scale.shape[2]
-                == self.intermediate_size // sf_block_size
+                == self.intermediate_size_per_partition // sf_block_size
             )
             assert (
                 layer.w13_weight_bias.dim() == 2
                 and layer.w13_weight_bias.shape[0] == self.num_experts
-                and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight_bias.shape[1]
+                == self.intermediate_size_per_partition * 2
             )
             assert (
                 layer.w2_weight_bias.dim() == 2
@@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 torch.stack(gemm1_scales_mxfp4_shuffled)
                 .reshape(
                     self.num_experts,
-                    2 * self.intermediate_size,
+                    2 * self.intermediate_size_per_partition,
                     self.hidden_size // sf_block_size,
                 )
                 .view(torch.float8_e4m3fn)
@@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 .reshape(
                     self.num_experts,
                     self.hidden_size,
-                    self.intermediate_size // sf_block_size,
+                    self.intermediate_size_per_partition // sf_block_size,
                 )
                 .view(torch.float8_e4m3fn)
             )
@@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         return tile_tokens_dim

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
         from sglang.srt.layers.moe.topk import TopKOutputChecker

+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
+
         if self.use_flashinfer:
             # When bf16 mode is enabled, we don't need to quantize the input,
             # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
@@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 top_k,
                 None,  # n_group # TODO: support n_group
                 None,  # topk_group # TODO: support topk_group
-                self.intermediate_size,  # padded to multiple of 256
+                self.intermediate_size_per_partition,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,
@@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
             )[0]
-            return trtllm_gen_output
+            return StandardCombineInput(hidden_states=trtllm_gen_output)

         if self.use_triton_kernels:
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
             if self.with_bias:
-                return self.triton_kernel_moe_with_bias_forward(
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=self.w13_weight_triton_tensor,
                     w1_pcg=self.w13_precision_config,
@@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     moe_runner_config=moe_runner_config,
                 )
             else:
-                return self.triton_kernel_moe_forward(
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
-            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-            return fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
-                b1=layer.w13_weight_bias,
-                b2=layer.w2_weight_bias,
-            )
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_weight_bias=layer.w13_weight_bias,
+                w2_weight_bias=layer.w2_weight_bias,
+            )
+            return self.runner.run(dispatch_output, quant_info)

 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
@@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         return w, mx_scales

-    def process_weights_after_loading(self, layer: Module) -> None:
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
         w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
@@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         topk_weights, topk_ids, _ = topk_output
         if _is_hip:
             topk_weights = topk_weights.to(
                 torch.float32
             )  # aiter's moe_sorting requires topk_weights to be FP32

-        return fused_moe(
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
             w2_scale=layer.w2_weight_scale,
             activation=(
                 ActivationType.Silu
-                if moe_runner_config.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                 else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
python/sglang/srt/layers/quantization/quark/quark_moe.py

@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle

+from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
 logger = logging.getLogger(__name__)

 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 OCP_MX_BLOCK_SIZE = 32

-if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
-    from sglang.srt.layers.quantization import QuarkConfig

-class QuarkMoEMethod:
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+class QuarkMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: QuarkConfig):
+        self.quant_config = quant_config

     @staticmethod
     def get_moe_method(
-        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
+        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
        module: torch.nn.Module,
        layer_name: str,
    ) -> "QuarkMoEMethod":
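The deleted __new__ built a throwaway subclass of FusedMoEMethodBase at instantiation time, apparently to dodge a circular import; with the base class now importable at module level, plain inheritance replaces it. A minimal standalone contrast of the two patterns (toy class names, not sglang code):

    # Toy illustration of the pattern change: dynamic re-basing via __new__
    # versus plain inheritance.
    class Base:
        def run(self):
            return "base machinery"

    # Before: re-parent the class on first instantiation. Note the instance
    # is not actually an instance of DynamicMethod afterwards.
    class DynamicMethod:
        def __new__(cls, *args, **kwargs):
            new_cls = type(
                cls.__name__,
                (Base,),
                {k: v for k, v in cls.__dict__.items()
                 if k not in ("__dict__", "__weakref__")},
            )
            return super(new_cls, new_cls).__new__(new_cls)

    # After: just inherit, which keeps isinstance() checks and the MRO sane.
    class PlainMethod(Base):
        def __init__(self, quant_config):
            self.quant_config = quant_config

    d = DynamicMethod()
    m = PlainMethod(quant_config=None)
    assert isinstance(d, Base) and not isinstance(d, DynamicMethod)
    assert isinstance(m, Base) and m.run() == "base machinery"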
@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
         layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
+
         topk_weights, topk_ids, _ = topk_output

-        return fused_moe(
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
python/sglang/srt/layers/quantization/unquant.py

@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     LinearMethodBase,
@@ -24,8 +26,10 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,
@@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.with_bias = with_bias

         # Fused gate_up_proj (column parallel)
-        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size
         if self.use_triton_kernels:
             w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
@@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         if self.with_bias:
             w13_weight_bias = torch.nn.Parameter(
-                torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                torch.empty(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=torch.float32,
+                ),
                 requires_grad=False,
             )
             layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         # down_proj (row parallel)
         w2_weight_n, w2_weight_k = (
             hidden_size,
-            intermediate_size,
+            intermediate_size_per_partition,
         )
         if self.use_triton_kernels:
             w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
@@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         return

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         return self.forward(
-            x=x,
             layer=layer,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+            dispatch_output=dispatch_output,
         )

     def forward_cuda(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
+
         if self.use_triton_kernels:
             if self.with_bias:
                 assert self.triton_kernel_moe_with_bias_forward is not None
-                return self.triton_kernel_moe_with_bias_forward(
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
@@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 )
             else:
                 assert self.triton_kernel_moe_forward is not None
-                return self.triton_kernel_moe_forward(
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"
@@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     topk_weights = torch.ones_like(
                         topk_weights, dtype=torch.float32
                     )  # topk_weights must be FP32 (float32)
-                return fused_moe(
+                output = fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                         else ActivationType.Gelu
                     ),
                 )
+                return StandardCombineInput(hidden_states=output)
             else:
-                from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-                    fused_experts,
-                )
-
-                return fused_experts(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    b1=getattr(layer, "w13_weight_bias", None),
-                    b2=getattr(layer, "w2_weight_bias", None),
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
+                quant_info = TritonMoeQuantInfo(
+                    w13_weight=layer.w13_weight,
+                    w2_weight=layer.w2_weight,
+                    b13=getattr(layer, "w13_weight_bias", None),
+                    b2=getattr(layer, "w2_weight_bias", None),
+                )
+                return self.runner.run(dispatch_output, quant_info)

     def forward_cpu(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
+
         assert (
             moe_runner_config.activation == "silu"
         ), f"activation = {moe_runner_config.activation} is not supported."
@@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x, topk_weights = apply_topk_weights_cpu(
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -348,33 +366,39 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 None,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
         else:
             from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

-            return moe_forward_native(
+            output = moe_forward_native(
                 layer,
                 x,
                 topk_output,
                 moe_runner_config,
             )
+            return StandardCombineInput(hidden_states=output)

     def forward_npu(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

-        return moe_forward_native(
+        output = moe_forward_native(
             layer,
             x,
             topk_output,
-            moe_runner_config,
+            self.moe_runner_config,
         )
+        return StandardCombineInput(hidden_states=output)

-    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
+    def forward_tpu(self, *args, **kwargs) -> CombineInput:
         raise NotImplementedError("The TPU backend currently does not support MoE.")

     forward_native = forward_cpu
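UnquantizedFusedMoEMethod mixes in CustomOp, so apply() just calls self.forward(...) and the mixin routes to forward_cuda, forward_cpu, forward_npu, or forward_tpu for the active platform, with forward_native aliased to the CPU path. A stripped-down sketch of that dispatch idea (the real CustomOp lives in sglang.srt.custom_op; this stand-in only assumes platform-keyed method selection):

    # Minimal stand-in for the CustomOp dispatch pattern used above:
    # forward() routes to a platform-specific implementation chosen at init.
    import torch

    class CustomOpSketch:
        def __init__(self):
            impl = "cuda" if torch.cuda.is_available() else "native"
            self._forward = getattr(self, f"forward_{impl}")

        def forward(self, *args, **kwargs):
            return self._forward(*args, **kwargs)

    class UnquantizedSketch(CustomOpSketch):
        def forward_cuda(self, x):
            return x  # would call the fused Triton/aiter kernels here

        def forward_cpu(self, x):
            return x  # would call torch.ops.sgl_kernel.fused_experts_cpu here

        forward_native = forward_cpu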
python/sglang/srt/layers/quantization/w4afp8.py

@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -22,7 +23,10 @@ from sglang.srt.utils import set_weight_attrs
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -133,7 +137,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -145,7 +149,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size * 2,
+                intermediate_size_per_partition * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
             ),
@@ -159,7 +163,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.empty(
                 num_experts,
                 hidden_size,
-                intermediate_size // 2,
+                intermediate_size_per_partition // 2,
                 dtype=torch.int8,
             ),
             requires_grad=False,
@@ -173,7 +177,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
@@ -186,7 +190,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.zeros(
                 num_experts,
                 hidden_size,
-                intermediate_size // self.quant_config.group_size,
+                intermediate_size_per_partition // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
             requires_grad=False,
@@ -220,13 +224,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         self.c_strides1 = torch.full(
             (num_experts, 3),
-            2 * intermediate_size,
+            2 * intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
             (num_experts, 3),
-            intermediate_size,
+            intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
@@ -282,16 +286,21 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: EPMoE,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # TODO(ch-wan): move it out of this class
-        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
@@ -328,6 +337,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
-        return output
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
+        return StandardCombineInput(hidden_states=output)
python/sglang/srt/layers/quantization/w8a8_fp8.py
View file @ 3fa62da7
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
...
@@ -26,8 +27,11 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_fp8_fnuz = is_fp8_fnuz()
...
@@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
...
@@ -218,7 +222,10 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=fp8_dtype,
             ),
             requires_grad=False,
         )
...
@@ -226,14 +233,21 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=fp8_dtype,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
...
@@ -266,25 +280,26 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
             per_channel_quant=True,
-            w1_scale=(layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)
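
This is the representative shape of the Triton-backed methods after the refactor: the backend is declared once in create_moe_runner, quantization state is packed into a TritonMoeQuantInfo, and the runner drives the kernel. The following is a condensed sketch of that flow, not a drop-in implementation; it assumes a layer that already carries the registered weights and scales, and that the constructor arguments shown in this diff are sufficient.

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo

# Declared once, typically in create_moe_runner().
runner = MoeRunner(MoeRunnerBackend.TRITON, MoeRunnerConfig(inplace=False))


def run_fp8_moe(layer, dispatch_output):
    # Quantization details travel in the quant-info object, not as kwargs
    # threaded through fused_experts().
    quant_info = TritonMoeQuantInfo(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        use_fp8_w8a8=True,
        per_channel_quant=True,
        w13_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a13_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
    )
    # The runner selects and launches the kernel, returning a combine input.
    return runner.run(dispatch_output, quant_info)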
python/sglang/srt/layers/quantization/w8a8_int8.py
View file @ 3fa62da7
...
@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
...
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
...
@@ -417,7 +421,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
...
@@ -428,7 +432,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
...
@@ -436,14 +443,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
...
@@ -483,23 +497,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
 
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
...
@@ -515,20 +536,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
             per_channel_quant=True,
-            w1_scale=(layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)
 
 
 class NPU_W8A8LinearMethodImpl:
...
@@ -900,7 +920,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
...
@@ -914,21 +934,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
...
@@ -941,7 +971,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
...
@@ -973,18 +1005,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-        x,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-        return npu_fused_experts(
+        output = npu_fused_experts(
             hidden_states=x,
             w13=layer.w13_weight,
             w13_scale=layer.w13_weight_scale,
...
@@ -994,3 +1033,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             top_k=topk_ids.shape[1],
         )
+        return StandardCombineInput(hidden_states=output)
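
A change that recurs across every backend in this commit is the rename of create_weights' intermediate_size parameter to intermediate_size_per_partition, making explicit that each rank allocates only its tensor-parallel shard of the expert weights. A toy shape calculation follows, under the conventional assumption that the per-partition size is the full intermediate size divided by the MoE TP degree; the sizes themselves are illustrative, not taken from any particular model.

import torch

num_experts = 8
hidden_size = 4096
intermediate_size = 14336
moe_tp_size = 4  # assumed tensor-parallel degree for the MoE layers

intermediate_size_per_partition = intermediate_size // moe_tp_size  # 3584

# Gate/up projections are fused, hence the factor of 2 on the shard size.
w13_weight = torch.empty(
    num_experts, 2 * intermediate_size_per_partition, hidden_size, dtype=torch.int8
)
# The down projection consumes the sharded intermediate dimension.
w2_weight = torch.empty(
    num_experts, hidden_size, intermediate_size_per_partition, dtype=torch.int8
)

assert w13_weight.shape == (8, 7168, 4096)
assert w2_weight.shape == (8, 4096, 3584)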
python/sglang/srt/managers/schedule_batch.py
View file @ 3fa62da7
...
@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.moe import is_tbo_enabled
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
...
python/sglang/srt/model_loader/__init__.py
View file @ 3fa62da7
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from torch import nn
 
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import (
     get_architecture_class_name,
     get_model_architecture,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+
 
 def get_model(
     *,
...
python/sglang/srt/model_loader/loader.py
View file @ 3fa62da7
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
+from __future__ import annotations
+
 # ruff: noqa: SIM117
 import collections
 import concurrent
...
@@ -14,7 +16,17 @@ import time
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 import huggingface_hub
 import numpy as np
...
@@ -26,9 +38,7 @@ from tqdm.auto import tqdm
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
-from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.connector import (
     ConnectorType,
     create_remote_connector,
...
@@ -39,7 +49,6 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
     post_load_weights,
...
@@ -70,6 +79,11 @@ from sglang.srt.utils import (
     set_weight_attrs,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
 _is_npu = is_npu()
...
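
Both model-loader files apply the same idiom: with `from __future__ import annotations`, annotations are never evaluated at runtime, so imports that exist only to satisfy the type checker can move under typing.TYPE_CHECKING, cutting import time and breaking import cycles. A self-contained illustration follows; the module and class names are placeholders, not sglang identifiers.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime.
    from heavy_module import HeavyConfig  # placeholder module


def build(config: HeavyConfig) -> None:
    # Under postponed evaluation this annotation stays a string at runtime,
    # so defining and calling build() never triggers the heavy import.
    print(type(config).__name__)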
python/sglang/test/test_cutlass_moe.py
View file @ 3fa62da7
...
@@ -9,6 +9,7 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
...
@@ -152,14 +153,32 @@ def run_test(tp_size, batch_size, model_config, check=False):
         problem_sizes2,
     )
 
+    topk_output = StandardTopKOutput(
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        router_logits=torch.randn(
+            (batch_size, topk), device=topk_weights.device, dtype=dtype
+        ),
+    )
+
+    moe_runner_config = MoeRunnerConfig(
+        num_experts=E,
+        topk=topk,
+        hidden_size=H,
+        shard_intermediate_size=I,
+        dtype=dtype,
+        block_shape=block_shape,
+        activation="silu",
+        inplace=False,
+    )
+
     # Note: Triton expects non-transposed weights
-    moe_config = MoeRunnerConfig(inplace=False)
     triton_lambda = lambda: fused_experts(
         x,
         w1,
         w2,
-        (topk_weights, topk_ids, "dummy"),
-        moe_config,
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
...
@@ -224,8 +243,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
         x,
         w1,  # Original shape
         w2,  # Original shape
-        (topk_weights, topk_ids, "dummy"),
-        moe_config,
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
...
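
The test now passes a real StandardTopKOutput in place of the old (topk_weights, topk_ids, "dummy") tuple. A small sketch of constructing one from raw router scores, assuming the sglang package is importable; the sizes are arbitrary and, as in the test itself, the logits only need the right shape.

import torch

from sglang.srt.layers.moe.topk import StandardTopKOutput

batch_size, num_experts, topk = 4, 8, 2
router_logits = torch.randn(batch_size, num_experts)

# Select the top-k experts per token and keep their normalized weights.
topk_weights, topk_ids = torch.topk(
    torch.softmax(router_logits, dim=-1), topk, dim=-1
)

topk_output = StandardTopKOutput(
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    router_logits=router_logits,
)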
test/srt/test_mla_deepseek_v3.py
View file @ 3fa62da7
+import os
 import unittest
 from types import SimpleNamespace
...
@@ -49,6 +50,42 @@ class TestMLADeepseekV3(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.62)
 
 
+class TestMLADeepseekV3DisableFusedFunc(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["SGLANG_CI_DISABLE_MOE_FUSED_FUNC"] = "1"
+        cls.model = "lmsys/sglang-ci-dsv3-test"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = ["--trust-remote-code", "--chunked-prefill-size", "256"]
+        if is_cuda():
+            other_args.extend(["--cuda-graph-max-bs", "2"])
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
 @unittest.skipIf(is_hip(), "FA is not available.")
 class TestMLADeepseekV3Fa3Fp8Kvcache(CustomTestCase):
     @classmethod
...
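
The new test class runs the same GSM8K evaluation with SGLANG_CI_DISABLE_MOE_FUSED_FUNC=1 exported before server launch, so the refactored non-fused path gets CI coverage. The diff does not show where the flag is read; below is a hedged sketch of how such a switch is typically consumed, with a hypothetical function name.

import os


def moe_fused_func_enabled() -> bool:
    # Hypothetical consumption site; the real check lives inside sglang.
    return os.environ.get("SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0") != "1"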