Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
67510e01
Commit
67510e01
authored
Oct 21, 2025
by
lizhigong
Committed by
maxiao1
Oct 25, 2025
Browse files
adaptation part w4A8 quantization
(cherry picked from commit
68277eac
)
parent
32b1ccaf
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
777 additions
and
0 deletions
+777
-0
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+2
-0
python/sglang/srt/layers/quantization/__init__.py
python/sglang/srt/layers/quantization/__init__.py
+2
-0
python/sglang/srt/layers/quantization/slimquant_w4a8.py
python/sglang/srt/layers/quantization/slimquant_w4a8.py
+408
-0
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
...n/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
+272
-0
python/sglang/srt/layers/quantization/w4a8_utils.py
python/sglang/srt/layers/quantization/w4a8_utils.py
+92
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+1
-0
No files found.
python/sglang/srt/configs/model_config.py
View file @
67510e01
...
@@ -614,6 +614,7 @@ class ModelConfig:
...
@@ -614,6 +614,7 @@ class ModelConfig:
"petit_nvfp4"
,
"petit_nvfp4"
,
"quark"
,
"quark"
,
"mxfp4"
,
"mxfp4"
,
"slimquant_w4a8_marlin"
,
]
]
optimized_quantization_methods
=
[
optimized_quantization_methods
=
[
"fp8"
,
"fp8"
,
...
@@ -633,6 +634,7 @@ class ModelConfig:
...
@@ -633,6 +634,7 @@ class ModelConfig:
"qoq"
,
"qoq"
,
"w4afp8"
,
"w4afp8"
,
"petit_nvfp4"
,
"petit_nvfp4"
,
"slimquant_w4a8_marlin"
,
]
]
compatible_quantization_methods
=
{
compatible_quantization_methods
=
{
"modelopt_fp4"
:
[
"modelopt"
],
"modelopt_fp4"
:
[
"modelopt"
],
...
...
python/sglang/srt/layers/quantization/__init__.py
View file @
67510e01
...
@@ -57,6 +57,7 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
...
@@ -57,6 +57,7 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
from
sglang.srt.layers.quantization.w4afp8
import
W4AFp8Config
from
sglang.srt.layers.quantization.w4afp8
import
W4AFp8Config
from
sglang.srt.layers.quantization.w8a8_fp8
import
W8A8Fp8Config
from
sglang.srt.layers.quantization.w8a8_fp8
import
W8A8Fp8Config
from
sglang.srt.layers.quantization.w8a8_int8
import
W8A8Int8Config
from
sglang.srt.layers.quantization.w8a8_int8
import
W8A8Int8Config
from
sglang.srt.layers.quantization.slimquant_w4a8_marlin
import
SlimQuantW4A8Int8MarlinConfig
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
mxfp_supported
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
mxfp_supported
_is_mxfp_supported
=
mxfp_supported
()
_is_mxfp_supported
=
mxfp_supported
()
...
@@ -83,6 +84,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -83,6 +84,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"w4afp8"
:
W4AFp8Config
,
"w4afp8"
:
W4AFp8Config
,
"petit_nvfp4"
:
PetitNvFp4Config
,
"petit_nvfp4"
:
PetitNvFp4Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"slimquant_w4a8_marlin"
:
SlimQuantW4A8Int8MarlinConfig
,
}
}
...
...
python/sglang/srt/layers/quantization/slimquant_w4a8.py
0 → 100644
View file @
67510e01
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
sglang.srt.layers.linear
import
set_weight_attrs
from
sglang.srt.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
sglang.srt.layers.linear
import
LinearBase
from
sglang.srt.layers.quantization.base_config
import
LinearMethodBase
,
QuantizationConfig
,
QuantizeMethodBase
,
FusedMoEMethodBase
from
sglang.srt.layers.parameter
import
(
ChannelQuantScaleParameter
,
_ColumnvLLMParameter
,
RowvLLMParameter
,
)
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
from
sglang.srt
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
import
os
class
ModelWeightParameter
(
_ColumnvLLMParameter
,
RowvLLMParameter
):
"""
Parameter class for linear layer weights. Uses both column and
row parallelism.
"""
pass
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
scales
=
scale_a
*
scale_b
.
T
gemmout
=
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))
output
=
(
scales
*
gemmout
).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
out_dtype
)
class
SlimQuantW4A8Int8Config
(
QuantizationConfig
):
"""Config class for W8A8 Int8 Quantization.
- Weight: static, per-channel, symmetric
- Activation: dynamic, per-token, symmetric
"""
def
__init__
(
self
):
pass
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
@
classmethod
def
get_name
(
self
)
->
str
:
return
"slimquant_w4a8"
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"SlimQuantW4A8Int8Config"
:
return
cls
()
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
if
isinstance
(
layer
,
LinearBase
):
return
SlimQuantW4A8Int8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
SlimQuantW4A8Int8MoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SlimQuantW4A8Int8LinearMethod
(
LinearMethodBase
):
def
__init__
(
self
,
quantization_config
:
SlimQuantW4A8Int8Config
):
self
.
quantization_config
=
quantization_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
n
=
layer
.
weight
.
shape
[
0
]
k
=
layer
.
weight
.
shape
[
1
]
if
self
.
w8a8_strategy
==
1
:
if
{
n
,
k
}
not
in
self
.
tritonsingleton
.
weight_shapes
:
self
.
tritonsingleton
.
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
self
.
tritonsingleton
.
triton_json_dict
.
update
(
configs_dict
)
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
else
:
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
layer
.
weight
.
data
=
_weight
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
self
.
logical_widths
=
output_partition_sizes
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
torch
.
int8
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
):
# if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
# assert len(input_quant_args) == 2
# x_q, x_scale = input_quant_args
# elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
# x_q, x_scale = silu_quant_args
# else:
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
if
self
.
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
k
=
x_q
.
shape
[
1
]
n
=
layer
.
weight
.
shape
[
1
]
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
best_config
=
None
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
m_
=
m
elif
m
<=
64
:
m_
=
(
m
+
3
)
&
-
4
#取值到最近的4的倍数
elif
m
<=
160
:
m_
=
(
m
+
7
)
&
-
8
elif
m
<
200
:
#256
m_
=
160
elif
m
<
480
:
#512
m_
=
256
elif
m
<
960
:
#1024
m_
=
512
elif
m
<
2048
:
m_
=
1024
elif
m
<
4096
:
m_
=
2048
elif
m
<
6000
:
m_
=
4096
else
:
m_
=
8192
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
#if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return
ops
.
triton_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
best_config
=
best_config
)
elif
self
.
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
else
:
return
ops
.
rocblas_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
class
SlimQuantW4A8Int8MoEMethod
:
"""MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def
__new__
(
cls
,
*
args
,
**
kwargs
):
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
if
not
hasattr
(
cls
,
"_initialized"
):
original_init
=
cls
.
__init__
new_cls
=
type
(
cls
.
__name__
,
(
FusedMoEMethodBase
,),
{
"__init__"
:
original_init
,
**
{
k
:
v
for
k
,
v
in
cls
.
__dict__
.
items
()
if
k
!=
"__dict__"
},
},
)
obj
=
super
(
new_cls
,
new_cls
).
__new__
(
new_cls
)
obj
.
__init__
(
*
args
,
**
kwargs
)
return
obj
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
quant_config
):
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
tp_size
=
get_tensor_model_parallel_world_size
()
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
w13_input_scale
=
None
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
w2_input_scale
=
None
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
E
=
layer
.
w13_weight
.
shape
[
0
]
N1
=
layer
.
w13_weight
.
shape
[
1
]
N2
=
layer
.
w2_weight
.
shape
[
1
]
K
=
N1
//
2
if
[
E
,
N1
,
N2
,
K
]
not
in
self
.
tritonsingleton
.
moe_weight_shapes
:
self
.
tritonsingleton
.
moe_weight_shapes
.
append
([
E
,
N1
,
N2
,
K
])
TOPK
=
self
.
tritonsingleton
.
topk
json_file
=
self
.
tritonsingleton
.
get_moeint8json_name
(
E
,
N1
,
N2
,
K
,
TOPK
,
use_int4_w4a8
=
True
)
configs_dict
=
self
.
tritonsingleton
.
get_moeint8_triton_cache
(
json_file
,
E
,
N1
,
N2
,
K
,
TOPK
)
#warmup
if
configs_dict
:
self
.
tritonsingleton
.
triton_moejson_dict
.
update
(
configs_dict
)
layer
.
w13_weight
=
Parameter
(
layer
.
w13_weight
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
layer
.
w2_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
from
sglang.srt.layers.moe.fused_moe_triton.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
)
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
0 → 100644
View file @
67510e01
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
sglang.srt
import
_custom_ops
as
ops
from
sglang.srt.utils
import
set_weight_attrs
from
sglang.srt.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
sglang.srt.layers.linear
import
LinearBase
from
sglang.srt.layers.quantization
import
QuantizationConfig
from
sglang.srt.layers.quantization.w4a8_utils
import
w4a8_weight_repack_impl
from
sglang.srt.layers.quantization.base_config
import
(
FusedMoEMethodBase
,
QuantizeMethodBase
)
from
sglang.srt.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
class
MarlinMoeWorkspace
:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device
"""
_instances
=
{}
def
__new__
(
cls
,
device
):
if
device
not
in
cls
.
_instances
:
instance
=
super
().
__new__
(
cls
)
instance
.
_initialized
=
False
cls
.
_instances
[
device
]
=
instance
return
cls
.
_instances
[
device
]
def
__init__
(
self
,
device
):
if
self
.
_initialized
:
return
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
self
.
workspace
=
torch
.
zeros
(
500
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
self
.
global_reduce_buffer
=
torch
.
zeros
(
sms
*
6
*
128
*
512
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
self
.
_initialized
=
True
def
get_buffers
(
self
):
return
self
.
workspace
,
self
.
global_reduce_buffer
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
scales
=
scale_a
*
scale_b
.
T
gemmout
=
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))
output
=
(
scales
*
gemmout
).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
out_dtype
)
class
SlimQuantW4A8Int8MarlinConfig
(
QuantizationConfig
):
"""Config class for W4A8 Int8 Quantization.
- Weight: static, per-channel, symmetric
- Activation: dynamic, per-token, symmetric
"""
def
__init__
(
self
):
pass
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
@
classmethod
def
get_name
(
self
)
->
str
:
return
"slimquant_w4a8_marlin"
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"SlimQuantW4A8Int8MarlinConfig"
:
return
cls
()
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
if
hf_quant_cfg
.
get
(
"quant_method"
)
==
"slimquant_w4a8"
\
and
user_quant
==
"slimquant_w4a8_marlin"
:
return
cls
.
get_name
()
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
if
isinstance
(
layer
,
LinearBase
):
return
SlimQuantW4A8Int8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
SlimQuantW4A8Int8MarlinMoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SlimQuantW4A8Int8MarlinMoEMethod
:
"""MoE method for W4A8INT8 Marlin.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Args:
quant_config: The quantization config.
"""
def
__new__
(
cls
,
*
args
,
**
kwargs
):
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
if
not
hasattr
(
cls
,
"_initialized"
):
original_init
=
cls
.
__init__
new_cls
=
type
(
cls
.
__name__
,
(
FusedMoEMethodBase
,),
{
"__init__"
:
original_init
,
**
{
k
:
v
for
k
,
v
in
cls
.
__dict__
.
items
()
if
k
!=
"__dict__"
},
},
)
obj
=
super
(
new_cls
,
new_cls
).
__new__
(
new_cls
)
obj
.
__init__
(
*
args
,
**
kwargs
)
return
obj
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
quant_config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
tp_size
=
get_tensor_model_parallel_world_size
()
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
w13_input_scale
=
None
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
w2_input_scale
=
None
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w13_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w13_weight
),
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w2_weight
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
from
sglang.srt.layers.moe.fused_moe_triton
import
(
FusedMoE
,
FusedMoeWeightScaleSupported
)
from
sglang.srt.layers.moe.fused_moe_triton.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
return
fused_experts_impl_w4a8_marlin
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
)
python/sglang/srt/layers/quantization/w4a8_utils.py
0 → 100644
View file @
67510e01
import
torch
import
numpy
as
np
try
:
from
lightop
import
awq_marlin_repack_w4a8
use_lightop
=
False
except
Exception
:
use_lightop
=
False
def
unpack_int8_to_int4
(
tensor_int8
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
Args:
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
Returns:
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
"""
if
tensor_int8
.
dtype
!=
torch
.
int8
:
raise
ValueError
(
"Input tensor must be of type torch.int8"
)
N
,
K_half
=
tensor_int8
.
shape
tensor_uint8
=
tensor_int8
.
to
(
torch
.
uint8
)
high4
=
tensor_uint8
&
0x0F
low4
=
(
tensor_uint8
>>
4
)
&
0x0F
unpacked
=
torch
.
empty
((
N
,
K_half
*
2
),
dtype
=
torch
.
int32
,
device
=
tensor_int8
.
device
)
unpacked
[:,
0
::
2
]
=
low4
.
to
(
torch
.
int32
)
unpacked
[:,
1
::
2
]
=
high4
.
to
(
torch
.
int32
)
return
unpacked
def
get_weight_perms
(
interleave
:
bool
=
True
):
perm
=
[]
for
i
in
range
(
64
):
for
col
in
range
(
4
):
cur_col
=
(
i
%
16
)
*
4
+
col
for
row
in
range
(
8
):
cur_row
=
(
i
//
16
)
*
8
+
row
cur_idx
=
cur_row
*
64
+
cur_col
perm
.
append
(
cur_idx
)
perm
=
np
.
array
(
perm
)
if
interleave
:
interleave
=
np
.
array
([
4
,
0
,
5
,
1
,
6
,
2
,
7
,
3
])
perm
=
perm
.
reshape
((
-
1
,
8
))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
return
perm
def
marlin_weights
(
q_w
,
weight_perm
,
k_tile
=
32
,
n_tile
=
64
,
pack_factor
=
8
):
size_k
,
size_n
=
q_w
.
shape
q_w
=
q_w
.
reshape
((
size_k
//
k_tile
,
k_tile
,
size_n
//
n_tile
,
n_tile
))
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
reshape
((
size_k
//
k_tile
,
size_n
*
k_tile
))
q_w
=
q_w
.
reshape
((
-
1
,
weight_perm
.
numel
()))[:,
weight_perm
].
reshape
(
q_w
.
shape
)
orig_device
=
q_w
.
device
q_w
=
q_w
.
contiguous
().
to
(
torch
.
int32
)
M
,
N
=
q_w
.
shape
assert
N
%
pack_factor
==
0
,
f
"size_n (
{
N
}
) must be divisible by pack_factor (
{
pack_factor
}
)"
q_packed
=
torch
.
zeros
((
M
,
N
//
pack_factor
),
dtype
=
torch
.
int32
,
device
=
orig_device
)
for
i
in
range
(
pack_factor
):
q_packed
+=
q_w
[:,
i
::
pack_factor
]
<<
(
4
*
i
)
return
q_packed
def
w4a8_2_marlin_weight
(
w4a8_w
):
full_w4a8_w
=
unpack_int8_to_int4
(
w4a8_w
)
full_w4a8_w
=
full_w4a8_w
.
T
weight_perm
=
get_weight_perms
()
marlin_q_w
=
marlin_weights
(
full_w4a8_w
,
weight_perm
,
k_tile
=
32
,
n_tile
=
64
,
pack_factor
=
8
)
return
marlin_q_w
def
w4a8_weight_repack_impl
(
input
):
if
use_lightop
:
size_batch
=
input
.
shape
[
0
]
size_n
=
input
.
shape
[
1
]
size_k
=
input
.
shape
[
2
]
*
2
output
=
torch
.
zeros
((
size_batch
,
size_k
//
32
,
size_n
*
4
),
device
=
input
.
device
,
dtype
=
torch
.
int32
)
awq_marlin_repack_w4a8
(
input
,
output
,
size_batch
,
size_k
,
size_n
)
else
:
w_marlin_list
=
[]
for
e
in
range
(
input
.
shape
[
0
]):
w_marlin_in
=
w4a8_2_marlin_weight
(
input
[
e
])
w_marlin_list
.
append
(
w_marlin_in
)
output
=
torch
.
stack
(
w_marlin_list
,
dim
=
0
)
return
output
python/sglang/srt/server_args.py
View file @
67510e01
...
@@ -93,6 +93,7 @@ QUANTIZATION_CHOICES = [
...
@@ -93,6 +93,7 @@ QUANTIZATION_CHOICES = [
"w4afp8"
,
"w4afp8"
,
"mxfp4"
,
"mxfp4"
,
"compressed-tensors"
,
# for Ktransformers
"compressed-tensors"
,
# for Ktransformers
"slimquant_w4a8_marlin"
,
]
]
ATTENTION_BACKEND_CHOICES
=
[
ATTENTION_BACKEND_CHOICES
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment