change / sglang · Commits · bf0f448f

Commit bf0f448f (unverified), authored Jul 27, 2025 by Cheng Wan, committed by GitHub on Jul 27, 2025.

[2/N] MoE Refactor: Unify weight loader and quant methods (#8397)

Parent: 36d6f0ba
Showing 5 changed files with 221 additions and 590 deletions:

    python/sglang/srt/layers/moe/ep_moe/layer.py              +87  -217
    python/sglang/srt/layers/moe/fused_moe_triton/layer.py    +31  -43
    python/sglang/srt/layers/quantization/fp8.py               +25  -247
    python/sglang/srt/layers/quantization/unquant.py           +10  -66
    python/sglang/srt/layers/quantization/w4afp8.py            +68  -17
python/sglang/srt/layers/moe/ep_moe/layer.py (+87 -217)

This diff is collapsed and is not reproduced here.
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+31 -43)
@@ -77,6 +77,7 @@ class FusedMoE(torch.nn.Module):
        routed_scaling_factor: Optional[float] = None,
        enable_flashinfer_cutlass_moe: Optional[bool] = False,
        enable_ep_moe: Optional[bool] = False,
+       skip_quant: Optional[bool] = False,
    ):
        super().__init__()
@@ -99,9 +100,6 @@ class FusedMoE(torch.nn.Module):
        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
        if enable_ep_moe:
            assert (
                self.enable_flashinfer_cutlass_moe
            ), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
            self.ep_size = self.tp_size
            self.ep_rank = self.tp_rank
            self.tp_size = 1
@@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module):
            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
            # Create a expert map for the local experts
            assert num_experts % self.ep_size == 0
-           self.local_num_experts = num_experts // self.ep_size
+           self.num_local_experts = num_experts // self.ep_size
-           self.expert_map[
-               self.ep_rank * self.local_num_experts : (self.ep_rank + 1) * self.local_num_experts
-           ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+           self.expert_map[
+               self.ep_rank * self.num_local_experts : (self.ep_rank + 1) * self.num_local_experts
+           ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
        else:
            self.ep_size = 1
            self.ep_rank = 0
-           self.local_num_experts = num_experts
+           self.num_local_experts = num_experts
        self.routed_scaling_factor = routed_scaling_factor
        assert intermediate_size % self.tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
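
To make the mapping concrete, here is a small standalone sketch (illustrative sizes, not the sglang API) of the expert map this constructor builds: global expert ids owned by the current EP rank get local indices, everything else stays -1.

    import torch

    def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
        assert num_experts % ep_size == 0
        num_local_experts = num_experts // ep_size
        expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
        start = ep_rank * num_local_experts
        expert_map[start : start + num_local_experts] = torch.arange(
            0, num_local_experts, dtype=torch.int32
        )
        return expert_map

    # 8 experts over 4 EP ranks; rank 1 owns global experts 2 and 3:
    print(build_expert_map(8, 4, 1))  # tensor([-1, -1,  0,  1, -1, -1, -1, -1], dtype=torch.int32)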
@@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module):
            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
        )

+       if skip_quant:
+           return
+
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                self.use_triton_kernels
@@ -149,7 +150,7 @@ class FusedMoE(torch.nn.Module):
        self.quant_config = quant_config
        self.quant_method.create_weights(
            layer=self,
-           num_experts=self.local_num_experts,
+           num_experts=self.num_local_experts,
            hidden_size=hidden_size,
            # FIXME: figure out which intermediate_size to use
            intermediate_size=self.intermediate_size_per_partition,
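
The rename above means create_weights is always handed the per-rank expert count. A quick shape sketch under assumed sizes (illustrative values only; as the @@ -99 hunk shows, enabling EP folds the TP group into EP, so tp_size becomes 1):

    # Hypothetical sizes, not taken from the diff.
    num_experts, ep_size, tp_size = 16, 4, 1
    hidden_size, intermediate_size = 4096, 2048

    num_local_experts = num_experts // ep_size                      # 4
    intermediate_size_per_partition = intermediate_size // tp_size  # 2048

    # Per-rank expert weights: fused gate_up (w13) and down (w2) projections.
    w13_shape = (num_local_experts, 2 * intermediate_size_per_partition, hidden_size)
    w2_shape = (num_local_experts, hidden_size, intermediate_size_per_partition)
    print(w13_shape, w2_shape)  # (4, 4096, 4096) (4, 4096, 2048)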
@@ -378,6 +379,23 @@ class FusedMoE(torch.nn.Module):
        if expert_id == -1:
            return

+       self._weight_loader_impl(
+           param=param,
+           loaded_weight=loaded_weight,
+           weight_name=weight_name,
+           shard_id=shard_id,
+           expert_id=expert_id,
+       )
+
+   def _weight_loader_impl(
+       self,
+       param: torch.nn.Parameter,
+       loaded_weight: torch.Tensor,
+       weight_name: str,
+       shard_id: str,
+       expert_id: int,
+   ) -> None:
+
        # TP rank is set to 0 if EP is enabled
        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
@@ -398,6 +416,10 @@ class FusedMoE(torch.nn.Module):
                f"shard_id must be ['w1','w2','w3'] but "
                f"got {shard_id}."
            )

+       # Flashinfer assumes w31 format for w13_weight. Same for the scales.
+       if getattr(self, "use_flashinfer_trtllm_moe", False):
+           shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+
        WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]

        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
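
A tiny standalone illustration of the swap above (the helper name is made up): a checkpoint shard labelled w1 lands in the w3 slot of the w31-ordered fused weight and vice versa, while w2 is untouched.

    def remap_shard_id_for_w31(shard_id: str) -> str:
        # gate (w1) and up (w3) trade places in w31 layout; down (w2) is unchanged
        return {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

    assert remap_shard_id_for_w31("w1") == "w3"
    assert remap_shard_id_for_w31("w3") == "w1"
    assert remap_shard_id_for_w31("w2") == "w2"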
@@ -605,37 +627,3 @@ class FusedMoE(torch.nn.Module):
                ("w3", ckpt_up_proj_name),
            ]
        ]

-   def _load_fp8_scale(
-       self,
-       param: torch.nn.Parameter,
-       loaded_weight: torch.Tensor,
-       weight_name: str,
-       shard_id: str,
-       expert_id: int,
-   ) -> None:
-       param_data = param.data
-
-       # Input scales can be loaded directly and should be equal.
-       if "input_scale" in weight_name:
-           if (
-               param_data[expert_id] != 1
-               and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-           ):
-               raise ValueError(
-                   "input_scales of w1 and w3 of a layer "
-                   f"must be equal. But got {param_data[expert_id]} "
-                   f"vs. {loaded_weight}"
-               )
-           param_data[expert_id] = loaded_weight
-       # Weight scales
-       elif "weight_scale" in weight_name:
-           # If we are in merged column case (gate_up_proj)
-           if shard_id in ("w1", "w3"):
-               # We have to keep the weight scales of w1 and w3 because
-               # we need to re-quantize w1/w3 weights after weight loading.
-               idx = 0 if shard_id == "w1" else 1
-               param_data[expert_id][idx] = loaded_weight
-           # If we are in the row parallel case (down_proj)
-           else:
-               param_data[expert_id] = loaded_weight
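
For reference, a minimal standalone demonstration (assumed shapes and values, not the sglang API) of the bookkeeping the removed helper performed and which the unified loader now covers: per-expert input scales are single values shared by w1 and w3, while per-expert w13 weight scales keep two slots so the shards can be re-quantized later.

    import torch

    num_local_experts = 4
    w13_input_scale = torch.ones(num_local_experts)       # one shared activation scale per expert
    w13_weight_scale = torch.ones(num_local_experts, 2)   # slot 0 for the w1 shard, slot 1 for w3

    expert_id = 1
    w13_input_scale[expert_id] = torch.tensor(0.07)       # same value arrives for w1 and w3
    w13_weight_scale[expert_id][0] = torch.tensor(0.011)  # w1 shard scale
    w13_weight_scale[expert_id][1] = torch.tensor(0.013)  # w3 shard scale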
python/sglang/srt/layers/quantization/fp8.py (+25 -247)
@@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig):
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        from sglang.srt.layers.linear import LinearBase
+       from sglang.srt.layers.moe.ep_moe.layer import EPMoE
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
@@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig):
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, EPMoE):
            return Fp8EPMoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
@@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
-                   layer.num_experts, dtype=torch.float32, device=w13_weight.device
+                   layer.num_local_experts,
+                   dtype=torch.float32,
+                   device=w13_weight.device,
                ),
                requires_grad=False,
            )
-           for expert in range(layer.num_experts):
+           for expert in range(layer.num_local_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
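
The loop above quantizes each local expert's w13 slice to fp8 with one per-tensor scale. A minimal standalone sketch of that per-expert step (an illustration only, not sglang's scaled_fp8_quant kernel):

    import torch

    def per_tensor_fp8_quant(w: torch.Tensor):
        # One scale per tensor, derived from the absolute maximum.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = w.abs().max().clamp(min=1e-12) / finfo.max
        w_q = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        return w_q, scale.to(torch.float32)

    w13 = torch.randn(4, 32, 16)                    # (num_local_experts, 2 * I, H), made-up sizes
    w13_q = torch.empty_like(w13, dtype=torch.float8_e4m3fn)
    w13_scale = torch.ones(4, dtype=torch.float32)
    for expert in range(w13.shape[0]):              # mirrors the per-expert loop above
        w13_q[expert], w13_scale[expert] = per_tensor_fp8_quant(w13[expert])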
@@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values

-           for expert_id in range(layer.num_experts):
+           for expert_id in range(layer.num_local_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
@@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values

-           for expert_id in range(layer.num_experts):
+           for expert_id in range(layer.num_local_experts):
                start = 0
                max_w13_scale_fp8 = max_w13_scales[expert_id]
                for shard_id in range(2):
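
Both loops above implement the same idea: w1 and w3 were quantized with separate scales, but the fused kernel wants a single w13 scale per expert, so each shard is dequantized with its own scale and re-quantized with the per-expert maximum. A compact standalone sketch (an illustrative helper, not sglang's per_tensor_dequantize path):

    import torch

    def requantize_w13_to_max_scale(w13_q: torch.Tensor, scales: torch.Tensor, shard_size: int):
        # w13_q: (2 * shard_size, hidden) fp8 weight of one expert; scales: (2,) per-shard scales.
        finfo = torch.finfo(torch.float8_e4m3fn)
        max_scale = scales.max()
        out = torch.empty_like(w13_q)
        start = 0
        for shard_id in range(2):
            dq = w13_q[start : start + shard_size].to(torch.float32) * scales[shard_id]
            out[start : start + shard_size] = (
                (dq / max_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
            )
            start += shard_size
        return out, max_scale

    w13_q = torch.randn(8, 16).to(torch.float8_e4m3fn)   # made-up fused weight for one expert
    shard_scales = torch.tensor([0.010, 0.014])
    w13_q, w13_scale = requantize_w13_to_max_scale(w13_q, shard_scales, shard_size=4)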
@@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
            # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post

-           for expert_id in range(layer.num_experts):
+           for expert_id in range(layer.num_local_experts):
                layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
                layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
@@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
+       from sglang.srt.layers.moe.ep_moe.layer import EPMoE
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

+       if isinstance(layer, EPMoE):
+           layer.w13_weight_scale = (
+               layer.w13_weight_scale_inv if self.block_quant else layer.w13_weight_scale
+           )
+           layer.w2_weight_scale = (
+               layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+           )
+           return layer.run_moe(
+               hidden_states=x,
+               topk_output=topk_output,
+           )
+
        if use_intel_amx_backend(layer):
            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        return None


-class Fp8EPMoEMethod(Fp8MoEMethod):
-    """MoE method for FP8.
-    Supports loading FP8 checkpoints with static weight scale and
-    dynamic/static activation scale.
-
-    Args:
-        quant_config: The quantization config.
-    """
-
-    def __init__(self, quant_config: Fp8Config):
-        self.quant_config = quant_config
-        self.block_quant = self.quant_config.weight_block_size is not None
-
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-        tp_size = get_tensor_model_parallel_world_size()
-        if self.block_quant:
-            block_n, block_k = (
-                self.quant_config.weight_block_size[0],
-                self.quant_config.weight_block_size[1],
-            )
-            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
-                raise ValueError(
-                    f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
-                    f"weight quantization block_n = {block_n}."
-                )
-            if tp_size > 1:
-                # Required by row parallel
-                if intermediate_size % block_k != 0:
-                    raise ValueError(
-                        f"The input_size of down's weight = "
-                        f"{intermediate_size} is not divisible by "
-                        f"weight quantization block_k = {block_k}."
-                    )
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.block_quant:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
-                    (hidden_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
-            assert self.quant_config.activation_scheme == "dynamic"
-        else:
-            # WEIGHT_SCALES
-            # Allocate 2 scales for w1 and w3 respectively.
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
-            if self.block_quant
-            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        # process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        # INPUT_SCALES
-        if self.quant_config.activation_scheme == "static":
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                raise ValueError(
-                    "Found static activation scheme for checkpoint that "
-                    "was not serialized fp8."
-                )
-            w13_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-        # If checkpoint is fp16, quantize in place.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    layer.num_experts_per_partition,
-                    dtype=torch.float32,
-                    device=w13_weight.device,
-                ),
-                requires_grad=False,
-            )
-
-            for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            return
-
-        # If checkpoint is fp8, we need to handle that the
-        # MoE kernels require single activation scale and single weight
-        # scale for w13 per expert.
-        else:
-            if self.quant_config.activation_scheme == "static":
-                if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.max(layer.w13_weight_scale, dim=1).values,
-                requires_grad=False,
-            )
-        if self.block_quant:
-            # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_fp8_fnuz:
-                # activation_scheme: dynamic
-                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=layer.w13_weight,
-                    weight_scale=layer.w13_weight_scale_inv,
-                    input_scale=None,
-                )
-                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=layer.w2_weight,
-                    weight_scale=layer.w2_weight_scale_inv,
-                    input_scale=None,
-                )
-                # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-                layer.w13_weight_scale_inv = torch.nn.Parameter(
-                    w13_weight_scale, requires_grad=False
-                )
-                layer.w13_input_scale = None
-                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                layer.w2_weight_scale_inv = torch.nn.Parameter(
-                    w2_weight_scale, requires_grad=False
-                )
-                layer.w2_input_scale = None
-            if _use_aiter:
-                layer.w13_weight = torch.nn.Parameter(
-                    shuffle_weight(layer.w13_weight.data, (16, 16)),
-                    requires_grad=False,
-                )
-                layer.w2_weight = torch.nn.Parameter(
-                    shuffle_weight(layer.w2_weight.data, (16, 16)),
-                    requires_grad=False,
-                )
-            return
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
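
For orientation, the block-quant branch of the removed create_weights sizes its scale tensors by ceil-division so that partial blocks still receive a scale. A small worked example with assumed sizes (illustrative values only):

    # Assumed sizes, for illustration only.
    num_experts_per_partition, intermediate_size, hidden_size = 8, 1536, 4096
    block_n, block_k = 128, 128

    w13_scale_shape = (
        num_experts_per_partition,
        2 * ((intermediate_size + block_n - 1) // block_n),   # gate and up stacked
        (hidden_size + block_k - 1) // block_k,
    )
    w2_scale_shape = (
        num_experts_per_partition,
        (hidden_size + block_n - 1) // block_n,
        (intermediate_size + block_k - 1) // block_k,
    )
    print(w13_scale_shape, w2_scale_shape)  # (8, 24, 32) (8, 32, 12)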
python/sglang/srt/layers/quantization/unquant.py (+10 -66)
@@ -24,6 +24,7 @@ from sglang.srt.utils import (
)

if TYPE_CHECKING:
+   from sglang.srt.layers.moe.ep_moe.layer import EPMoE
    from sglang.srt.layers.moe.topk import TopKOutput

has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -194,6 +195,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
+       from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+
+       if isinstance(layer, EPMoE):
+           return layer.run_moe(
+               hidden_states=x,
+               topk_output=topk_output,
+           )
+
        return self.forward(
            x=x,
            layer=layer,
@@ -354,69 +364,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            raise NotImplementedError("The TPU backend currently does not support MoE.")

    forward_native = forward_cpu

-
-class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # scale
-        layer.register_parameter("w13_input_scale", None)
-        layer.register_parameter("w13_weight_scale", None)
-
-        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-
-        w2_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_input_scale", w2_input_scale)
-        set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        w2_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ) -> torch.Tensor:
-        raise NotImplementedError
python/sglang/srt/layers/quantization/w4afp8.py (+68 -17)
from __future__ import annotations

import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from torch.nn import Module
@@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs

+if TYPE_CHECKING:
+   from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
+
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)
@@ -84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig):
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        from sglang.srt.layers.linear import LinearBase
+       from sglang.srt.layers.moe.ep_moe.layer import EPMoE
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
-       elif isinstance(layer, FusedMoE):
+       elif isinstance(layer, EPMoE):
            return W4AFp8MoEMethod(self)
        return None
@@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
    def create_weights(
        self,
-       layer: Module,
-       num_experts_per_partition: int,
+       layer: EPMoE,
+       num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
@@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
-               num_experts_per_partition,
+               num_experts,
                intermediate_size * 2,
                hidden_size // 2,
                dtype=torch.int8,
@@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
-               num_experts_per_partition,
+               num_experts,
                hidden_size,
                intermediate_size // 2,
                dtype=torch.int8,
@@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
-               num_experts_per_partition,
+               num_experts,
                2 * intermediate_size,
                hidden_size // self.quant_config.group_size,
                dtype=torch.float32,
@@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
-               num_experts_per_partition,
+               num_experts,
                hidden_size,
                intermediate_size // self.quant_config.group_size,
                dtype=torch.float32,
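
The int4 weights above keep one float32 scale per group_size input channels. A quick shape check with assumed sizes (illustrative values only):

    # Assumed sizes, for illustration only.
    num_experts, hidden_size, intermediate_size, group_size = 4, 1024, 2048, 128

    w13_weight_scale_shape = (num_experts, 2 * intermediate_size, hidden_size // group_size)
    w2_weight_scale_shape = (num_experts, hidden_size, intermediate_size // group_size)
    print(w13_weight_scale_shape, w2_weight_scale_shape)  # (4, 4096, 8) (4, 1024, 16)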
@@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        # Input scales
        w13_input_scale = torch.nn.Parameter(
-           torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+           torch.ones((num_experts, 2), dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
-           torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+           torch.ones(num_experts, dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
@@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        device = layer.w13_weight.device

        self.a_strides1 = torch.full(
-           (num_experts_per_partition, 3),
+           (num_experts, 3),
            hidden_size,
            device=device,
            dtype=torch.int64,
        )
        self.c_strides1 = torch.full(
-           (num_experts_per_partition, 3),
+           (num_experts, 3),
            2 * intermediate_size,
            device=device,
            dtype=torch.int64,
        )
        self.a_strides2 = torch.full(
-           (num_experts_per_partition, 3),
+           (num_experts, 3),
            intermediate_size,
            device=device,
            dtype=torch.int64,
        )
        self.c_strides2 = torch.full(
-           (num_experts_per_partition, 3),
+           (num_experts, 3),
            hidden_size,
            device=device,
            dtype=torch.int64,
@@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        self.s_strides2 = self.c_strides2

        self.expert_offsets = torch.empty(
-           (num_experts_per_partition + 1), dtype=torch.int32, device=device
+           (num_experts + 1), dtype=torch.int32, device=device
        )
        self.problem_sizes1 = torch.empty(
-           (num_experts_per_partition, 3), dtype=torch.int32, device=device
+           (num_experts, 3), dtype=torch.int32, device=device
        )
        self.problem_sizes2 = torch.empty(
-           (num_experts_per_partition, 3), dtype=torch.int32, device=device
+           (num_experts, 3), dtype=torch.int32, device=device
        )

        return
@@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
            [w2_input_scale_max], dtype=dtype, device=device
        )
        layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
+
+   def apply(
+       self,
+       layer: EPMoE,
+       hidden_states: torch.Tensor,
+       topk_output: TopKOutput,
+   ) -> torch.Tensor:
+
+       # TODO(ch-wan): move it out of this class
+       from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
+       topk_ids, topk_weights, _ = topk_output
+       local_topk_ids = topk_ids
+       if layer.expert_map is not None:
+           "Translate info from expert_map to topk_ids"
+           local_topk_ids = torch.where(
+               layer.expert_map[topk_ids] != layer.num_experts,
+               layer.expert_map[topk_ids],
+               layer.num_experts,
+           )
+
+       return cutlass_w4a8_moe(
+           layer.start_expert_id,
+           layer.end_expert_id,
+           layer.num_experts,
+           hidden_states,
+           layer.w13_weight,
+           layer.w2_weight,
+           layer.w13_weight_scale_inv,
+           layer.w2_weight_scale_inv,
+           topk_weights,
+           topk_ids,
+           local_topk_ids,
+           self.a_strides1,
+           self.b_strides1,
+           self.c_strides1,
+           self.a_strides2,
+           self.b_strides2,
+           self.c_strides2,
+           self.s_strides13,
+           self.s_strides2,
+           self.expert_offsets,
+           self.problem_sizes1,
+           self.problem_sizes2,
+           layer.w13_input_scale,
+           layer.w2_input_scale,
+       )
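
A self-contained illustration (assumed sizes and sentinel convention, not the sglang API) of the expert_map translation in the apply method above: global top-k expert ids are mapped to this rank's local ids, and ids owned by other ranks collapse to the sentinel value num_experts.

    import torch

    num_experts, ep_size, ep_rank = 8, 4, 1
    num_local_experts = num_experts // ep_size

    # Assumed map: local experts get 0..num_local_experts-1, all others the sentinel.
    expert_map = torch.full((num_experts,), num_experts, dtype=torch.int64)
    start = ep_rank * num_local_experts
    expert_map[start : start + num_local_experts] = torch.arange(num_local_experts)

    topk_ids = torch.tensor([[2, 5], [3, 0]])
    local_topk_ids = torch.where(
        expert_map[topk_ids] != num_experts, expert_map[topk_ids], num_experts
    )
    print(local_topk_ids)  # tensor([[0, 8], [1, 8]]) on rank 1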