Commit b4dff7f5
Authored Oct 21, 2025 by lizhigong; committed by maxiao1 on Oct 25, 2025
Parent: c0352f4a

adaptation w4A8 quantization
Showing 3 changed files with 764 additions and 0 deletions (+764 -0):

  python/sglang/srt/_custom_ops.py                                  +31   -0
  python/sglang/srt/layers/quantization/slimquant_w4a8.py          +415   -0
  python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py   +318   -0
python/sglang/srt/_custom_ops.py
@@ -5,6 +5,15 @@ from typing import List, Optional, Tuple

import torch

from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

try:
    from lmslim import quant_ops
    from lmslim import quant_tools
except Exception:
    print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")

try:
    import lightop
except Exception:
    print("INFO: Please install lightop if you want to infer awq or marlin.\n")

logger = logging.getLogger(__name__)

use_vllm_custom_allreduce = get_bool_env_var(
...

@@ -175,3 +184,25 @@ def mscclpp_allreduce(
    context: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    nthreads: int,
    nblocks: int,
) -> None:
    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)


def triton_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
    best_config: Optional[list] = None,
) -> torch.Tensor:
    return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)


def triton_int8_gemm_helper(
    m: int,
    n: int,
    k: int,
    per_token_act_quant: bool,
    per_out_channel_weight_quant: bool,
    use_bias: bool,
    out_dtype: type[torch.dtype] = torch.float16,
    device: str = "cuda:0",
    best_config: Optional[list] = None,
    repeat: Optional[int] = 2,
):
    return quant_tools.triton_int8_gemm_helper(
        m, n, k,
        per_token_act_quant, per_out_channel_weight_quant, use_bias,
        out_dtype, device, best_config, repeat,
    )
\ No newline at end of file
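As a rough usage sketch of the new wrappers (illustrative only, not part of this commit): triton_scaled_mm takes an int8 activation/weight pair plus their scales and folds the dequantization into the GEMM. The weight and scale tensors below are placeholders; per_token_quant_int8 is the lmslim helper also used by slimquant_w4a8.py further down.

import torch
from sglang.srt import _custom_ops as ops
from lmslim.layers.gemm.int8_utils import per_token_quant_int8

x = torch.randn(32, 4096, dtype=torch.float16, device="cuda")                      # [M, K] activations
w_int8 = torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda")   # [K, N] int8 weight (placeholder)
w_scale = torch.rand(4096, 1, dtype=torch.float32, device="cuda")                  # [N, 1] per-channel scales (placeholder)

x_q, x_scale = per_token_quant_int8(x)        # dynamic per-token activation quantization
out = ops.triton_scaled_mm(
    x_q, w_int8,
    scale_a=x_scale, scale_b=w_scale,
    out_dtype=x.dtype, bias=None,
    best_config=None,                          # None: fall back to the kernel's default tiling
)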
python/sglang/srt/layers/quantization/slimquant_w4a8.py (new file, mode 100644)
from typing import Any, Callable, Dict, List, Optional

import torch
from sglang.srt.layers.linear import set_weight_attrs
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import (
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
    FusedMoEMethodBase,
)
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    _ColumnvLLMParameter,
    RowvLLMParameter,
)
from lmslim.layers.gemm.int8_utils import (
    per_token_group_quant_int8,
    per_token_quant_int8,
)
from sglang.srt import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os


class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.
    """

    pass


W8A8_TRITONJSON = W8a8GetCacheJSON()


def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    scales = scale_a * scale_b.T
    gemmout = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    output = (scales * gemmout).to(out_dtype)
    if bias is not None:
        output = output + bias
    return output.to(out_dtype)
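# Illustrative note (not part of this commit): baseline_scaled_mm is the fp32 reference for
# the scaled int8 GEMM. scale_a is per-token ([M, 1]) and scale_b per-output-channel ([N, 1]);
# scale_a * scale_b.T broadcasts to [M, N], so element (i, j) of the output is
# scale_a[i] * scale_b[j] * (a @ b)[i, j]. A minimal sketch of using it to sanity-check a
# fused kernel:
#
#     a = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
#     b = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
#     scale_a = torch.rand(4, 1, dtype=torch.float32)    # per token
#     scale_b = torch.rand(16, 1, dtype=torch.float32)   # per output channel
#     ref = baseline_scaled_mm(a, b, scale_a, scale_b, torch.float16)
#     assert ref.shape == (4, 16)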
class SlimQuantW4A8Int8Config(QuantizationConfig):
    """Config class for W4A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_name(cls) -> str:
        return "slimquant_w4a8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )

        if isinstance(layer, LinearBase):
            return SlimQuantW4A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return SlimQuantW4A8Int8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
        self.quantization_config = quantization_config
        self.tritonsingleton = W8a8GetCacheJSON()
        self.w8a8_strategy = int(os.getenv("W8A8_SUPPORT_METHODS", "1"))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        n = layer.weight.shape[0]
        k = layer.weight.shape[1]
        if self.w8a8_strategy == 1:
            if {n, k} not in self.tritonsingleton.weight_shapes:
                self.tritonsingleton.weight_shapes.append({n, k})
                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(json_file, n, k)
                if configs_dict:
                    self.tritonsingleton.triton_json_dict.update(configs_dict)
                    for key, value in configs_dict.items():
                        m = int(key.split("_")[0])
                        ops.triton_int8_gemm_helper(
                            m=m,
                            n=n,
                            k=k,
                            per_token_act_quant=True,
                            per_out_channel_weight_quant=True,
                            use_bias=False,
                            device=layer.weight.device,
                            best_config=value,
                        )
        else:
            weight_data = layer.weight.data
            _weight = weight_data.T.contiguous().reshape(n, -1)
            layer.weight.data = _weight
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=torch.int8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        input_quant_args: Optional[list[torch.Tensor]] = None,
        silu_quant_args: Optional[list[torch.Tensor]] = None,
    ):
        # if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
        #     assert len(input_quant_args) == 2
        #     x_q, x_scale = input_quant_args
        # elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
        #     x_q, x_scale = silu_quant_args
        # else:
        x_q, x_scale = per_token_quant_int8(x)
        if self.w8a8_strategy == 1:
            m = x_q.shape[0]
            k = x_q.shape[1]
            n = layer.weight.shape[1]
            if len(W8A8_TRITONJSON.triton_json_dict) == 0:
                best_config = None
            elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
                if m <= 16:
                    m_ = m
                elif m <= 64:
                    m_ = (m + 3) & -4  # round up to the nearest multiple of 4
                elif m <= 160:
                    m_ = (m + 7) & -8  # round up to the nearest multiple of 8
                elif m < 200:  # 256
                    m_ = 160
                elif m < 480:  # 512
                    m_ = 256
                elif m < 960:  # 1024
                    m_ = 512
                elif m < 2048:
                    m_ = 1024
                elif m < 4096:
                    m_ = 2048
                elif m < 6000:
                    m_ = 4096
                else:
                    m_ = 8192
                best_config = W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
            else:
                best_config = None
            # if best_config == None:
            #     print("m:{},n:{},k:{}".format(m, n, k))
            #     print("config not found!")
            return ops.triton_scaled_mm(
                x_q,
                layer.weight,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=bias,
                best_config=best_config,
            )
        elif self.w8a8_strategy == 2:
            return ops.cutlass_scaled_mm(
                x_q,
                layer.weight,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=bias,
            )
        else:
            return ops.rocblas_scaled_mm(
                x_q,
                layer.weight,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                out_dtype=x.dtype,
                bias=bias,
            )
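# Illustrative note (not part of this commit): the m_ bucketing in apply() maps the runtime
# token count m onto the keys that were pre-tuned into the Triton config JSON.
# (m + 3) & -4 rounds up to the next multiple of 4 and (m + 7) & -8 to the next multiple of 8:
#
#     (17 + 3) & -4   -> 20          (45 + 3) & -4   -> 48
#     (65 + 7) & -8   -> 72          (100 + 7) & -8  -> 104
#
# Larger m values collapse into fixed buckets (160, 256, 512, 1024, 2048, 4096, 8192) so the
# lookup f"{m_}_{n}_{k}" hits an existing entry instead of tuning at request time.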
class SlimQuantW4A8Int8MoEMethod:
    """MoE method for W4A8INT8.

    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )

        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config
        self.tritonsingleton = W8a8GetCacheJSON()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size,
                hidden_size // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        w13_input_scale = None
        layer.register_parameter("w13_input_scale", w13_input_scale)
        w2_input_scale = None
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        E = layer.w13_weight.shape[0]
        N1 = layer.w13_weight.shape[1]
        N2 = layer.w2_weight.shape[1]
        K = N1 // 2
        if [E, N1, N2, K] not in self.tritonsingleton.moe_weight_shapes:
            self.tritonsingleton.moe_weight_shapes.append([E, N1, N2, K])
            TOPK = self.tritonsingleton.topk
            json_file = self.tritonsingleton.get_moeint8json_name(
                E, N1, N2, K, TOPK, use_int4_w4a8=True
            )
            configs_dict = self.tritonsingleton.get_moeint8_triton_cache(
                json_file, E, N1, N2, K, TOPK
            )
            # warmup
            if configs_dict:
                self.tritonsingleton.triton_moejson_dict.update(configs_dict)

        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data, requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        use_fused_gate: Optional[bool] = False,
        **_,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet."
            )

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            use_fused_gate=use_fused_gate,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_int4_w4a8=True,
            per_channel_quant=True,
            activation=activation,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=use_nn_moe,
        )
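Both MoE methods allocate w13_weight with a last dimension of hidden_size // 2 and w2_weight with intermediate_size // 2 because each int8 byte stores two 4-bit weights. A minimal packing sketch (illustrative only; the nibble ordering actually expected by the fused w4a8 kernels is an assumption here):

import torch

def pack_int4_pairs(w: torch.Tensor) -> torch.Tensor:
    # Pack int4 values in [-8, 7] along the last dim: two values per byte,
    # even column in the low nibble, odd column in the high nibble (assumed layout).
    assert w.shape[-1] % 2 == 0
    nibbles = (w.to(torch.int16) & 0xF).to(torch.uint8)   # two's-complement nibbles
    low, high = nibbles[..., 0::2], nibbles[..., 1::2]
    return (low | (high << 4)).view(torch.int8)           # reinterpret packed bytes as int8

w = torch.randint(-8, 8, (2, 8, 16), dtype=torch.int8)    # [experts, out_features, in_features]
packed = pack_int4_pairs(w)                                # [2, 8, 8]: last dim halved, as in w13_weight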
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py (new file, mode 100644)
from typing import Any, Callable, Dict, List, Optional

from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
from sglang.srt.layers.moe.token_dispatcher.standard import (
    StandardCombineInput,
    StandardDispatchOutput,
)
import torch
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import set_weight_attrs
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig

try:
    from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import (
        fused_experts_impl_w4a8_marlin,
    )
except Exception:
    print("INFO: Please install lmslim if you want to infer a quantized MoE model.\n")


class MarlinMoeWorkspace:
    """
    Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
    global_reduce_buffer takes 1.5 MB * CUs (about 120 MB on BW200) of memory on each device.
    """

    _instances = {}

    def __new__(cls, device):
        if device not in cls._instances:
            instance = super().__new__(cls)
            instance._initialized = False
            cls._instances[device] = instance
        return cls._instances[device]

    def __init__(self, device):
        if self._initialized:
            return
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        self.workspace = torch.zeros(
            500, dtype=torch.int, device=device, requires_grad=False
        )
        self.global_reduce_buffer = torch.zeros(
            sms * 6 * 128 * 512,
            dtype=torch.int,
            device=device,
            requires_grad=False,
        )
        self._initialized = True

    def get_buffers(self):
        return self.workspace, self.global_reduce_buffer
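# Illustrative note (not part of this commit): sizing check for the buffers above. Each
# SM/CU contributes 6 * 128 * 512 int32 elements to global_reduce_buffer:
#
#     per_sm_bytes = 6 * 128 * 512 * 4          # 1,572,864 bytes = 1.5 MiB, matching the docstring
#     total_mib = 80 * per_sm_bytes / 2**20     # ~120 MiB on a hypothetical 80-SM device
#
# The 500-element workspace tensor is negligible by comparison.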
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    scales = scale_a * scale_b.T
    gemmout = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    output = (scales * gemmout).to(out_dtype)
    if bias is not None:
        output = output + bias
    return output.to(out_dtype)


class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
    """Config class for W4A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_name(cls) -> str:
        return "slimquant_w4a8_marlin"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig":
        return cls()

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        if (
            hf_quant_cfg.get("quant_method") == "slimquant_w4a8"
            and user_quant == "slimquant_w4a8_marlin"
        ):
            return cls.get_name()
        return None

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )

        if isinstance(layer, LinearBase):
            return SlimQuantW4A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return SlimQuantW4A8Int8MarlinMoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class SlimQuantW4A8Int8MarlinMoEMethod:
    """MoE method for W4A8INT8 Marlin.

    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )

        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import (
            FusedMoE,
            FusedMoeWeightScaleSupported,
        )

        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = intermediate_size_per_partition

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size,
                hidden_size // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        w13_input_scale = None
        layer.register_parameter("w13_input_scale", w13_input_scale)
        w2_input_scale = None
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data, requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )
        layer.w13_weight = Parameter(
            w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False
        )
        layer.w2_weight = Parameter(
            w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False
        )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        topk_weights, topk_ids, _ = topk_output
        x, topk_weights = apply_topk_weights_cpu(
            self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
        )

        workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()

        output = fused_experts_impl_w4a8_marlin(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            workspace=workspace,
            global_reduce_buffer=global_reduce_buffer,
            inplace=True,
            use_int4_w4a8=True,
            per_channel_quant=True,
            activation=layer.moe_runner_config.activation,
            expert_map=layer.expert_map_gpu,
            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
            global_num_experts=layer.moe_runner_config.num_experts,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=False,
        )
        return StandardCombineInput(hidden_states=output)
# def _apply(
# self,
# layer: torch.nn.Module,
# x: torch.Tensor,
# router_logits: torch.Tensor,
# top_k: int,
# #renormalize: bool,
# #use_grouped_topk: bool = False,
# topk_group: Optional[int] = None,
# num_expert_group: Optional[int] = None,
# global_num_experts: int = -1,
# expert_map: Optional[torch.Tensor] = None,
# custom_routing_function: Optional[Callable] = None,
# scoring_func: str = "softmax",
# e_score_correction_bias: Optional[torch.Tensor] = None,
# apply_router_weight_on_input: bool = False,
# activation: str = "silu",
# enable_eplb: bool = False,
# use_nn_moe: Optional[bool] = False,
# routed_scaling_factor: Optional[float] = None,
# use_fused_gate: Optional[bool] = False,
# **_
# ) -> torch.Tensor:
# from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
# if enable_eplb:
# raise NotImplementedError(
# "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# # Expert selection
# topk_weights, topk_ids = FusedMoE.select_experts(
# hidden_states=x,
# router_logits=router_logits,
# #use_grouped_topk=use_grouped_topk,
# top_k=top_k,
# #renormalize=renormalize,
# topk_group=topk_group,
# num_expert_group=num_expert_group,
# custom_routing_function=custom_routing_function,
# scoring_func=scoring_func,
# e_score_correction_bias=e_score_correction_bias,
# routed_scaling_factor=routed_scaling_factor,
# use_fused_gate=use_fused_gate
# )
# workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
# return fused_experts_impl_w4a8_marlin(
# x,
# layer.w13_weight,
# layer.w2_weight,
# topk_weights=topk_weights,
# topk_ids=topk_ids,
# workspace=workspace,
# global_reduce_buffer=global_reduce_buffer,
# inplace=True,
# use_int4_w4a8=True,
# per_channel_quant=True,
# activation=activation,
# expert_map=expert_map,
# apply_router_weight_on_input=apply_router_weight_on_input,
# global_num_experts=global_num_experts,
# w1_scale=(layer.w13_weight_scale),
# w2_scale=(layer.w2_weight_scale),
# a1_scale=layer.w13_input_scale,
# a2_scale=layer.w2_input_scale,
# use_nn_moe=use_nn_moe,
# )
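As a usage sketch (illustrative only): the marlin path is chosen when the checkpoint's quantization config declares slimquant_w4a8 but the user explicitly requests slimquant_w4a8_marlin; override_quantization_method implements exactly that dispatch.

hf_quant_cfg = {"quant_method": "slimquant_w4a8"}   # as read from the model checkpoint
name = SlimQuantW4A8Int8MarlinConfig.override_quantization_method(
    hf_quant_cfg, user_quant="slimquant_w4a8_marlin"
)
assert name == "slimquant_w4a8_marlin"
# Any other combination returns None, and the plain slimquant_w4a8 config is used instead.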