change / sglang · Commits · c28ad199

Unverified commit c28ad199, authored Jul 17, 2025 by Peng Zhang, committed by GitHub Jul 16, 2025
Parent: 570d3343

[1/n] chore: decouple quantization implementation from vLLM dependency (#7992)

Showing 13 changed files with 1478 additions and 616 deletions (+1478 −616)
Files changed (13):

    python/sglang/srt/layers/moe/fused_moe_triton/__init__.py    +4    -1
    python/sglang/srt/layers/quantization/__init__.py            +2    -4
    python/sglang/srt/layers/quantization/gptq.py                 +491  -119
    python/sglang/srt/layers/quantization/marlin_utils.py         +781  -0
    python/sglang/srt/layers/quantization/moe_wna16.py            +30   -0
    python/sglang/srt/layers/quantization/quant_utils.py          +0    -166
    python/sglang/srt/layers/quantization/scalar_type.py          +0    -0
    python/sglang/srt/layers/quantization/utils.py                +162  -1
    sgl-kernel/python/sgl_kernel/fused_moe.py                     +2    -1
    sgl-kernel/tests/test_marlin_repack.py                        +2    -4
    test/srt/test_gptqmodel_dynamic.py                            +4    -5
    test/srt/test_int4_kernel.py                                  +0    -301
    test/srt/test_w4a8.py                                         +0    -14
python/sglang/srt/layers/moe/fused_moe_triton/__init__.py

from contextlib import contextmanager
from typing import Any, Dict, Optional

import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_experts,
    get_config_file_name,
    moe_align_block_size,
    try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
    FusedMoE,
...
@@ -37,4 +38,6 @@ __all__ = [
    "fused_moe",
    "fused_experts",
    "get_config_file_name",
    "moe_align_block_size",
    "try_get_optimal_moe_config",
]
python/sglang/srt/layers/quantization/__init__.py

...
@@ -22,10 +22,6 @@ try:
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
    )
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
...
@@ -59,7 +55,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import (
    GPTQConfig,
    GPTQLinearMethod,
    GPTQMarlinConfig,
    GPTQMarlinLinearMethod,
    GPTQMarlinMoEMethod,
)
from sglang.srt.layers.quantization.modelopt_quant import (
...
...
python/sglang/srt/layers/quantization/gptq.py

import logging
from dataclasses import dataclass
from fractions import Fraction
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from sglang.srt.layers.linear import LinearBase, set_weight_attrs
from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    RowvLLMParameter,
    permute_param_layout_,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import replace_parameter
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

from sglang.srt.layers.quantization.marlin_utils import (
    apply_gptq_marlin_linear,
    check_marlin_supported,
    check_marlin_supports_shape,
    marlin_is_k_full,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_moe_permute_scales,
    marlin_permute_scales,
    marlin_repeat_scales_on_all_ranks,
    marlin_sort_g_idx,
    marlin_zero_points,
    verify_marlin_supported,
)
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols

try:
    from vllm import _custom_ops as ops
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        FusedMoE,
        FusedMoEMethodBase,
        FusedMoeWeightScaleSupported,
        GPTQMarlinLinearMethod,
        marlin_moe_permute_scales,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
        check_marlin_supported,
    )
    from vllm.scalar_type import scalar_types

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    ops = None
    GPTQLinearMethod = MarlinLinearMethod = Any

    from sglang.srt.utils import is_cuda

    FusedMoEMethodBase = QuantizeMethodBase
    _is_cuda = is_cuda()

    class scalar_types:
        uint4b8 = "uint4b8"
        uint8b128 = "uint8b128"

if _is_cuda:
    from sgl_kernel import fused_marlin_moe

FusedMoEMethodBase = QuantizeMethodBase

logger = logging.getLogger(__name__)
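Throughout this commit the same guard pattern appears: vLLM symbols are imported inside a try block, and local placeholders are bound when vLLM is absent. A minimal standalone sketch of that pattern; the helper name has_vllm_kernels is illustrative only and does not appear in the diff:

    from typing import Any

    try:
        # Optional fast path: present only when vLLM is installed.
        from vllm import _custom_ops as ops  # noqa: F401

        VLLM_AVAILABLE = True
    except ImportError:
        # Fallback: keep the module importable without vLLM.
        VLLM_AVAILABLE = False
        ops = None
        # Placeholder types so annotations and isinstance checks still resolve.
        GPTQLinearMethod = MarlinLinearMethod = Any


    def has_vllm_kernels() -> bool:
        # Illustrative helper: callers branch on the flag instead of importing vllm.
        return VLLM_AVAILABLE and ops is not None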
...
...
@@ -54,6 +62,38 @@ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    )


def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
            b_q_weight[e], perm[e], size_k, size_n, num_bits
        )
    return output
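A quick sanity check on the allocation above: a 4-bit repack keeps size_k // 16 tile rows per expert and widens the columns by num_bits // 2. The sketch below only mirrors that shape arithmetic; the sizes are made up and no kernel is called:

    def marlin_repacked_shape(num_experts: int, size_k: int, size_n: int, num_bits: int):
        # Mirrors the torch.empty(...) call in gptq_marlin_moe_repack above.
        assert size_k % 16 == 0
        return (num_experts, size_k // 16, size_n * (num_bits // 2))

    # Example: 8 experts, K=4096, N=14336, 4-bit weights.
    print(marlin_repacked_shape(8, 4096, 14336, 4))  # (8, 256, 28672)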
@dataclass
class MarlinLinearLayerConfig:
    full_weight_shape: tuple[int, int]  # [in, out]
    partition_weight_shape: tuple[int, int]
    weight_type: ScalarType
    act_type: torch.dtype
    group_size: int
    zero_points: bool
    has_g_idx: bool


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.
...
...
@@ -151,11 +191,16 @@ class GPTQConfig(QuantizationConfig):
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[GPTQLinearMethod]:
    ) -> Optional["LinearMethodBase"]:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization import get_linear_quant_method

        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
        if isinstance(layer, LinearBase):
            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
        elif isinstance(layer, FusedMoE):
            raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
        return None


class GPTQMarlinConfig(QuantizationConfig):
...
...
@@ -313,14 +358,6 @@ class GPTQMarlinConfig(QuantizationConfig):
        if isinstance(layer, FusedMoE):
            return GPTQMarlinMoEMethod(self)
            # TODO: re-enable after SGLang syncs with vllm >= 0.7.3
            # if layer.num_experts > 32:
            #     # For MoEs with many experts the moe_wna16 kernel is faster
            #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
            #         layer, prefix
            #     )
            # else:
            #     return GPTQMarlinMoEMethod(self)
        return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)

    @classmethod
...
...
@@ -344,112 +381,439 @@ class GPTQMarlinConfig(QuantizationConfig):
        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        assert (
            VLLM_AVAILABLE
        ), "vllm is not installed, to use gptq_marlin, please install vllm"
        return check_marlin_supported(
            quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
        )


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.
class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(
    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        group_size: int,
        lm_head_quantized: bool,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        self.lm_head_quantized = lm_head_quantized
        if self.group_size != 128 and self.group_size != -1:
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs.get("weight_loader")
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin, but got group_size of "
                f"{self.group_size}"
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        self.use_shuffle = True
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (
            input_size != input_size_per_partition
            and self.quant_config.group_size != -1
        ):
            if self.quant_config.desc_act:
                self.use_shuffle = False
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        # Tile size used by marlin kernels.
        self.tile_size = 16
        g_idx = RowvLLMParameter(
            data=torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                dtype=torch.int32,
            ),
            input_dim=0,
            weight_loader=weight_loader,
        )

        qzeros_args = {
            "data": torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }

        if scale_and_zero_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )
        # Min out_features dim
        self.min_n_threads = 64
        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )
        # Min in_features dim
        self.min_k_threads = 128

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # for torch.compile
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        # exllama needs to shuffle the weight after the weight is loaded
        # here we do the shuffle on first forward pass
        if self.use_shuffle:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
            else:
                layer.g_idx.data = torch.empty(
                    (0,), dtype=torch.int, device=layer.g_idx.device
                )
            ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
        reshaped_x = x.reshape(-1, x.shape[-1])
        output = ops.gptq_gemm(
            reshaped_x,
            layer.qweight,
            layer.qzeros,
            layer.scales,
            layer.g_idx,
            self.use_shuffle,
            self.quant_config.weight_bits,
        )
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)

    def __repr__(self) -> str:
        return (
            f"MarlinConfig(group_size={self.group_size}, "
            f"lm_head_quantized={self.lm_head_quantized})"


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    _kernel_backends_being_used: set[str] = set()

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

        # Verify supported on platform.
        verify_marlin_supported(
            quant_type=self.quant_config.quant_type,
            group_size=self.quant_config.group_size,
        )

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        self.kernel_config = MarlinLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act,
        )

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(
            self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
        ):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each GPU in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

        # Activation order
        g_idx = RowvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            weight_loader=weight_loader,
        )

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

        qzeros_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(group_size, lm_head_quantized)

        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        is_marlin_format = check_marlin_format(hf_quant_cfg)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

        is_valid_user_quant = (
            user_quant is None or user_quant == "gptq" or user_quant == "marlin"

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, "qweight").device
        c = self.kernel_config

        check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )

        if is_marlin_format and is_valid_user_quant:
            msg = "The model is serialized in {} format. Using {} kernel.".format(
                cls.get_name(), cls.get_name()
        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        self.w_q_name = "qweight"
        self.w_s_name = "scales"
        self.w_zp_name = "qzeros"
        self.w_gidx_name = "g_idx"

        def _transform_param(
            layer: torch.nn.Module, name: Optional[str], fn: Callable
        ) -> None:
            if name is not None and getattr(layer, name, None) is not None:
                old_param = getattr(layer, name)
                new_param = fn(old_param)
                # replace the parameter with torch.nn.Parameter for TorchDynamo
                # compatibility
                replace_parameter(
                    layer,
                    name,
                    torch.nn.Parameter(new_param.data, requires_grad=False),
                )

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
            )
            logger.info(msg)
            return cls.get_name()
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
            )
            return x

        return None
        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[MarlinLinearMethod]:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            _transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
        _transform_param(layer, self.w_q_name, transform_w_q)
        _transform_param(layer, self.w_s_name, transform_w_s)

        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            return MarlinLinearMethod(self)
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        c = self.kernel_config

        def _get_weight_params(
            layer: torch.nn.Module,
        ) -> tuple[
            torch.Tensor,  # w_q
            torch.Tensor,  # w_s
            Optional[torch.Tensor],  # w_zp,
            Optional[torch.Tensor],  # w_gidx
        ]:
            return (
                getattr(layer, self.w_q_name),
                getattr(layer, self.w_s_name),
                getattr(layer, self.w_zp_name or "", None),
                getattr(layer, self.w_gidx_name or "", None),
            )

        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias,
        )
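For reference, the tensors registered by GPTQLinearMethod.create_weights above follow a fixed shape scheme. The sketch below reproduces that bookkeeping for the single-GPU, non-desc_act case; the helper itself is illustrative and not part of the diff (pack_factor is 32 // weight_bits):

    def gptq_checkpoint_shapes(in_features: int, out_features: int,
                               weight_bits: int = 4, group_size: int = 128):
        # Mirrors the qweight/qzeros/scales/g_idx parameters created above.
        pack_factor = 32 // weight_bits
        groups = in_features // (group_size if group_size != -1 else in_features)
        return {
            "qweight": (in_features // pack_factor, out_features),  # int32, packed on K
            "qzeros": (groups, out_features // pack_factor),        # int32, packed on N
            "scales": (groups, out_features),                       # params_dtype
            "g_idx": (in_features,),                                 # int32
        }

    # Example: a 4096 -> 11008 projection quantized to 4 bits with group size 128.
    print(gptq_checkpoint_shapes(4096, 11008))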
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
...
@@ -467,6 +831,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        intermediate_size = extra_weight_attrs.pop("intermediate_size")

        self.is_k_full = (not self.quant_config.desc_act) or (
...
...
@@ -644,20 +1011,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            requires_grad=False,
        )
        # Repack weights
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
        marlin_w13_qweight = gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
            self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
        marlin_w2_qweight = gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
            self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
...
...
@@ -698,13 +1065,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
    ) -> torch.Tensor:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.topk import select_experts

        assert activation == "silu", "Only SiLU activation is supported."
        assert (
            scoring_func == "softmax"
        ), "Only softmax score func is supported for now."

        # The input must currently be float16
        orig_dtype = x.dtype
        x = x.half()

        topk_weights, topk_ids = FusedMoE.select_experts(
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
...
...
@@ -713,11 +1086,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            correction_bias=e_score_correction_bias,
        )

        return torch.ops.vllm.fused_marlin_moe(
        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
...
...
@@ -730,6 +1102,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            quant_type_id=self.quant_config.quant_type.id,
            num_bits=self.quant_config.weight_bits,
            is_k_full=self.is_k_full,
        ).to(orig_dtype)
python/sglang/srt/layers/quantization/marlin_utils.py  (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py

import logging
from typing import Any, Optional

import numpy
import torch

from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedvLLMParameter,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability

try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

logger = logging.getLogger(__name__)

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True


# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(
    has_zp: Optional[bool] = None,
    include_fp_type: bool = True,
    device_capability: Optional[int] = None,
):
    if device_capability is None:
        major, minor = get_device_capability()
        capability = major * 10 + minor
        device_capability = -1 if capability is None else capability

    if device_capability < 80:
        return []

    # - has_zp is True: return quant_types that has zero points
    # - has_zp is False: return quant_types that has not zero points
    # - has_zp is None: both
    if has_zp is None:
        types0 = query_marlin_supported_quant_types(
            False, include_fp_type, device_capability
        )
        types1 = query_marlin_supported_quant_types(
            True, include_fp_type, device_capability
        )
        return types0 + types1

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4]
    else:
        # GPTQ style, unsigned + symmetric bias
        res = [scalar_types.uint4b8, scalar_types.uint8b128]
        if include_fp_type:
            res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
        return res


def _check_marlin_supported(
    quant_type: ScalarType,
    group_size: Optional[int],
    has_zp: bool,
    device_capability: Optional[int] = None,
) -> tuple[bool, Optional[str]]:

    if device_capability is None:
        major, minor = get_device_capability()
        capability = major * 10 + minor
        device_capability = -1 if capability is None else capability

    supported_types = query_marlin_supported_quant_types(
        has_zp, True, device_capability
    )

    if quant_type not in supported_types:
        return (
            False,
            f"Marlin does not support weight_bits = {quant_type}. "
            f"Only types = {supported_types} "
            f"are supported (for group_size = {group_size}, "
            f"device_capability = {device_capability}, zp = {has_zp}).",
        )
    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        return (
            False,
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported.",
        )

    return True, None


def check_marlin_supported(
    quant_type: ScalarType,
    group_size: int,
    has_zp: bool = False,
    device_capability: Optional[int] = None,
) -> bool:
    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
    return cond


def verify_marlin_supported(
    quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError(err_msg)


def verify_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> None:

    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            f"Weight output_size_per_partition = "
            f"{output_size_per_partition} is not divisible by "
            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"Weight input_size_per_partition = "
            f"{input_size_per_partition} is not divisible "
            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    if group_size < input_size and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )


def check_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> tuple[bool, Optional[str]]:
    try:
        verify_marlin_supports_shape(
            output_size_per_partition, input_size_per_partition, input_size, group_size
        )
    except ValueError as e:
        return False, e.__str__()
    return True, None


def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
    output_size_per_partition = (
        getattr(layer, "output_size_per_partition", None) or layer.output_size
    )
    input_size_per_partition = (
        getattr(layer, "input_size_per_partition", None) or layer.input_size
    )

    return check_marlin_supports_shape(
        output_size_per_partition=output_size_per_partition,
        input_size_per_partition=input_size_per_partition,
        input_size=layer.input_size,
        group_size=group_size,
    )[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
    hidden_size = layer.hidden_size
    intermediate_size_per_partition = layer.intermediate_size_per_partition
    # apply_router_weight_on_input is not supported for moe marlin
    supports_router_weight = not layer.apply_router_weight_on_input
    # moe marlin requires the activation to be silu
    supports_activation = layer.activation == "silu"

    # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
    # down: (n, k) = (hidden_size, intermediate_size_per_partition)
    # moe marlin requires n % 128 == 0 and k % 64 == 0
    supports_shape = (
        hidden_size % 128 == 0
        and intermediate_size_per_partition % max(64, group_size) == 0
    )
    supports_group_size = group_size in [-1, 32, 64, 128]
    return (
        supports_shape
        and supports_group_size
        and supports_router_weight
        and supports_activation
    )


def marlin_make_workspace(
    device: torch.device, max_blocks_per_sm: int = 1
) -> torch.Tensor:
    # In the new marlin kernel, we use the num of threadblocks as workspace
    # size. The num of threadblocks is is sms_count * max_blocks_per_sm.
    sms = torch.cuda.get_device_properties(device).multi_processor_count
    return torch.zeros(
        sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
    )


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    return (not act_order) or (act_order and not is_row_parallel)


def marlin_repeat_scales_on_all_ranks(
    act_order: bool, group_size: int, is_row_parallel: bool
) -> bool:
    # Need to repeat scales on every rank if act_ordering or
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(
        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
    )


def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(
        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
    )


def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices
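marlin_is_k_full and marlin_sort_g_idx are pure host-side helpers, so their behaviour can be checked without a GPU. A small sketch, assuming an sglang checkout is on PYTHONPATH:

    import torch

    from sglang.srt.layers.quantization.marlin_utils import (
        marlin_is_k_full,
        marlin_sort_g_idx,
    )

    # With activation reordering on a row-parallel shard, each rank only holds
    # part of K, so the kernel must be told K is not "full".
    print(marlin_is_k_full(act_order=True, is_row_parallel=True))   # False
    print(marlin_is_k_full(act_order=False, is_row_parallel=True))  # True

    # Sorting g_idx groups activation-ordered rows by quantization group.
    g_idx = torch.tensor([2, 0, 1, 3], dtype=torch.int32)
    sorted_g_idx, sort_indices = marlin_sort_g_idx(g_idx)
    print(sorted_g_idx.tolist(), sort_indices.tolist())  # [0, 1, 2, 3] [1, 2, 0, 3]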
def get_scale_perms():
    scale_perm: list[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: list[int] = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:

    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    num_experts = s.shape[0]
    output = torch.empty(
        (num_experts, s.shape[1], s.shape[2]),
        device=s.device,
        dtype=s.dtype,
    )

    for e in range(num_experts):
        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
    return output
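The scale permutation is a pure tensor reshuffle, so it can be exercised on CPU. A small usage sketch, assuming an sglang checkout is on PYTHONPATH (the sizes are arbitrary):

    import torch

    from sglang.srt.layers.quantization.marlin_utils import (
        get_scale_perms,
        marlin_permute_scales,
    )

    scale_perm, scale_perm_single = get_scale_perms()
    assert len(scale_perm) == 64 and len(scale_perm_single) == 32

    # Grouped case: K=256, N=64, group_size=128 -> two scale rows per column.
    s = torch.randn(256 // 128, 64, dtype=torch.float16)
    s_marlin = marlin_permute_scales(s, size_k=256, size_n=64, group_size=128)
    print(s_marlin.shape)  # torch.Size([2, 64]); same values, Marlin column order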
def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp


def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp


def moe_awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
):
    num_experts = q_zp_packed.shape[0]
    output = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    for e in range(num_experts):
        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
    return output
def maybe_warn_marlin_atomic_add(device, dtype):
    if torch.compiler.is_dynamo_compiling():
        return
    device_capability = torch.cuda.get_device_capability(device)
    if device_capability[0] < 9 and dtype == torch.bfloat16:
        logger.info_once(
            "You are running Marlin kernel with bf16 on GPUs before SM90. "
            "You can consider change to fp16 to achieve better performance "
            "if possible."
        )


def maybe_warn_marlin_atomic_add_env():
    if torch.compiler.is_dynamo_compiling():
        return
    # TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
    if True:
        return
    # if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
    #     return
    logger.info_once(
        "Marlin kernel can achieve better performance for small size_n "
        "with experimental use_atomic_add feature. "
        "You can consider set environment variable "
        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible."
    )


def should_use_atomic_add_reduce(
    m: int, n: int, k: int, device: torch.device, dtype: torch.dtype
) -> bool:

    # the performance of atomicAdd is better than global reduce
    # only when m*n is small and k is large
    if n >= 2048 or k < 2048 or device.type != "cuda":
        return False

    # disable atomicAdd reduce by default,
    # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
    # TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
    if not True:
        maybe_warn_marlin_atomic_add_env()
        return False

    # sm8x doesn't support atomicAdd + bfloat16 natively
    device_capability = torch.cuda.get_device_capability(device)
    if device_capability[0] < 9 and dtype == torch.bfloat16:
        maybe_warn_marlin_atomic_add(device, dtype)
        return False

    return True
def apply_gptq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    wtype: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    is_k_full: bool,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition,)

    use_atomic_add = should_use_atomic_add_reduce(
        m=reshaped_x.size(0),
        n=output_size_per_partition,
        k=reshaped_x.size(1),
        device=input.device,
        dtype=input.dtype,
    )

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        None,
        weight,
        weight_scale,
        None,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        wtype,
        size_m=reshaped_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def apply_awq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    quant_type: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition,)

    use_atomic_add = should_use_atomic_add_reduce(
        m=reshaped_x.size(0),
        n=output_size_per_partition,
        k=reshaped_x.size(1),
        device=input.device,
        dtype=input.dtype,
    )

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        None,
        weight,
        weight_scale,
        None,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        size_m=reshaped_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
        lm_head_quantized: bool,
    ) -> None:
        super().__init__()

        # Group size for the quantization.
        self.group_size = group_size
        self.lm_head_quantized = lm_head_quantized
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin, but got group_size of "
                f"{self.group_size}"
            )

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return (
            f"MarlinConfig(group_size={self.group_size}, "
            f"lm_head_quantized={self.lm_head_quantized})"
        )

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(group_size, lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_marlin_format: bool
        is_marlin_format = hf_quant_cfg.get(
            "checkpoint_format"
        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)

        is_valid_user_quant = (
            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
        )

        if is_marlin_format and is_valid_user_quant:
            msg = "The model is serialized in {} format. Using {} kernel.".format(
                cls.get_name(), cls.get_name()
            )
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["MarlinLinearMethod"]:
        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            return MarlinLinearMethod(self)
        return None


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs["weight_loader"]

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}"
            )

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}."
            )
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}."
            )

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}."
            )
        if (
            self.quant_config.group_size != -1
            and input_size_per_partition % self.quant_config.group_size != 0
        ):
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {self.quant_config.group_size}."
            )

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2
        )
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError("Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition
                * self.quant_config.tile_size
                // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            marlin_tile_size=self.quant_config.tile_size,
            weight_loader=weight_loader,
        )

        # Determine if channelwise or not
        input_groups = (
            1
            if self.quant_config.group_size == -1
            else input_size_per_partition // self.quant_config.group_size
        )

        weight_scale_args = {
            "data": torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition // self.quant_config.min_n_threads
        ) * self.quant_config.max_parallel

        workspace = BasevLLMParameter(
            data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
            weight_loader=weight_loader,
        )

        layer.register_parameter("B", qweight)
        layer.register_parameter("s", scales)
        layer.register_parameter("workspace", workspace)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # required by torch.compile
        layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False)
        layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False)
        layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.B
        scales = layer.s
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(
            x_2d, qweight, scales, workspace, size_m, size_n, size_k
        )

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
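The legacy MarlinLinearMethod above also follows a fixed shape scheme (4-bit pack_factor 8, tile_size 16, min_n_threads 64, max_parallel 16). The sketch below only mirrors that arithmetic and is not part of the diff:

    def marlin_layer_shapes(in_features: int, out_features: int, group_size: int = 128):
        # Mirrors the B / s / workspace parameters registered in create_weights above.
        tile_size, pack_factor, min_n_threads, max_parallel = 16, 8, 64, 16
        input_groups = 1 if group_size == -1 else in_features // group_size
        return {
            "B": (in_features // tile_size, out_features * tile_size // pack_factor),
            "s": (input_groups, out_features),
            "workspace": ((out_features // min_n_threads) * max_parallel,),
        }

    # Example: a 4096 -> 4096 projection with group size 128.
    print(marlin_layer_shapes(4096, 4096))
    # {'B': (256, 8192), 's': (32, 4096), 'workspace': (1024,)}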
python/sglang/srt/layers/quantization/moe_wna16.py

...
@@ -19,6 +19,36 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs

logger = logging.getLogger(__name__)


def get_weight_perm(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


class MoeWNA16Config(QuantizationConfig):
    """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
...
...
python/sglang/srt/layers/quantization/quant_utils.py  (deleted, mode 100644 → 0)

# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

from typing import Optional

import numpy
import torch

from sgl_kernel.scalar_type import ScalarType


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
sgl-kernel/python/sgl_kernel/scalar_type.py → python/sglang/srt/layers/quantization/scalar_type.py

File moved
python/sglang/srt/layers/quantization/utils.py

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType
from typing import List, Mapping, Tuple, Union
from typing import List, Mapping, Optional, Tuple, Union

import numpy
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

_is_cuda = is_cuda()
...
...
@@ -143,3 +145,162 @@ def replace_parameter(
    if not isinstance(new, torch.nn.Parameter):
        new = torch.nn.Parameter(new, requires_grad=False)
    mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
def
get_pack_factor
(
num_bits
):
assert
32
%
num_bits
==
0
,
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
    mask = (1 << num_bits) - 1

    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
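A minimal round-trip sketch for the two helpers above; shapes and values are illustrative only, and in application code they are imported from sglang.srt.layers.quantization.utils:

import torch

size_k, size_n, num_bits = 4, 16, 4
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)  # shape (size_k, size_n // 8)
unpacked = unpack_cols(packed, num_bits, size_k, size_n)
assert torch.equal(unpacked, q_w)  # lossless round trip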
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
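For orientation, a hedged usage sketch of the relocated quantize_weights; the module paths come from this diff, while the tensor shapes and the uint4b8 choice are illustrative:

import torch

from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import quantize_weights

w = torch.randn(256, 128, dtype=torch.float32)
# 4-bit group-wise quantization, symmetric (uint4b8 carries a bias instead of zero-points)
w_ref, w_q, w_s, w_zp = quantize_weights(w, scalar_types.uint4b8, group_size=64)

# w_ref is the dequantized reference; w_zp is None because zero_points defaults to False
rel_err = (w - w_ref).abs().mean() / w.abs().mean()
print(f"mean relative dequantization error: {rel_err.item():.3f}")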
sgl-kernel/python/sgl_kernel/fused_moe.py
View file @ c28ad199
...
...
@@ -2,10 +2,11 @@ import functools
from typing import Optional

import torch
from sgl_kernel.scalar_type import scalar_types


def get_scalar_type(num_bits: int, has_zp: bool):
    from sglang.srt.layers.quantization.scalar_type import scalar_types

    if has_zp:
        assert num_bits == 4
        return scalar_types.uint4
...
...
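The scalar_types import now happens lazily inside get_scalar_type and points at the sglang copy, so importing sgl_kernel.fused_moe no longer needs the removed sgl_kernel.scalar_type module at load time. A small usage sketch of the branch visible above:

from sgl_kernel.fused_moe import get_scalar_type

qtype = get_scalar_type(num_bits=4, has_zp=True)  # returns scalar_types.uint4 per the hunk above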
sgl-kernel/tests/test_marlin_repack.py
View file @ c28ad199
import math

import numpy as np
import pytest
import torch
from sgl_kernel import awq_marlin_repack
from sgl_kernel.scalar_type import scalar_types

from sglang.srt.layers.quantization.quant_utils import (
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
    get_pack_factor,
    pack_cols,
    quantize_weights,
...
...
test/srt/test_gptqmodel_dynamic.py
View file @ c28ad199
...
...
@@ -51,13 +51,12 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
        model_config=model_config, load_config=load_config, device_config=device_config
    )

    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
    from sglang.srt.layers.linear import UnquantizedLinearMethod
    from sglang.srt.layers.quantization.gptq import (
        GPTQLinearMethod,
        GPTQMarlinLinearMethod,
    )
    from sglang.srt.layers.linear import UnquantizedLinearMethod

    linear_method_cls = (
        GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod)
    )
...
...
@@ -162,7 +161,7 @@ class TestGPTQModelDynamicWithMarlin(CustomTestCase):
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--dtype", "float16"],
            other_args=["--dtype", "bfloat16"],
        )

    @classmethod
...
...
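A condensed view of the updated check (use_marlin_kernel and the surrounding fixture come from the test itself; nothing new is introduced here):

from sglang.srt.layers.quantization.gptq import GPTQLinearMethod, GPTQMarlinLinearMethod

# The expected linear method is now resolved from sglang's own gptq module
linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else GPTQLinearMethod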
test/srt/test_int4_kernel.py
deleted 100644 → 0
View file @ 570d3343
import itertools
import sys
import unittest

import torch

sys.path.insert(0, "/home/hadoop-hmart-waimai-rank/vllm")
# from sglang.srt.layers.moe.topk import select_experts
from sgl_kernel import fused_marlin_moe
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

# from vllm.model_executor.layers. import select_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize,
)
from vllm.scalar_type import scalar_types


def stack_and_dev(tensors: list[torch.Tensor]):
    dev = tensors[0].device
    return torch.stack(tensors, dim=0).to(dev)


def torch_moe(a, w1, w2, score, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                0, 1
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"

    # Reshape input
    M = A.numel() // A.shape[-1]
    B = B.t()  # Transpose weight matrix
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (K,)
    A = A.reshape(M, N)

    # As is per-token [M, 1], Bs is per-column [1, K]
    C = torch.matmul(A, B)  # [M, K]
    C = As * C * Bs.view(1, -1)  # Broadcast per-column scale

    return C.reshape(origin_C_shape).to(output_dtype)


def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """This function performs fused moe with per-column int8 quantization using native torch."""
    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
            )
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)
            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
            )
    # Apply routing weights and sum
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


def marlin_fused_moe(
    N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size
):
    quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
    if ep_size > 1:
        local_e = E // ep_size
        e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    zeros1_l = []
    g_idx1_l = []
    sort_indices1_l = []
    s1_l = []

    for i in range(w1.shape[0]):
        test_perm = torch.randperm(n=K)
        quant_res = marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
        )
        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
        w_ref1_l.append(w_ref1.T)
        qweight1_l.append(qweight1)
        scales1_l.append(scales1)
        g_idx1_l.append(g_idx1)
        sort_indices1_l.append(sort_indices1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None

    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    zeros2_l = []
    g_idx2_l = []
    sort_indices2_l = []

    for i in range(w2.shape[0]):
        test_perm = torch.randperm(n=N)
        quant_res = marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
        )
        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
        w_ref2_l.append(w_ref2.T)
        qweight2_l.append(qweight2)
        scales2_l.append(scales2)
        g_idx2_l.append(g_idx2)
        sort_indices2_l.append(sort_indices2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None

    topk_weights, topk_ids = fused_topk(a, score, topk, False)
    # topk_weights, topk_ids = FusedMoE.select_experts(
    #     hidden_states=a,
    #     router_logits=score,
    #     top_k=topk,
    #     num_expert_group=E,
    #     use_grouped_topk=False,
    #     renormalize=False,
    #     topk_group=None,
    # )
    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=E,
        expert_map=e_map,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        num_bits=num_bits,
        is_k_full=True,
    )
    return marlin_output, torch_output


class TestW8A8Int8FusedMoE(unittest.TestCase):
    DTYPES = [torch.float16]
    M = [1, 16]
    N = [128]
    K = [256]
    E = [4, 10]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]
    NUM_BITS = [4]
    EP_SIZE = [1, 4]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w4a8_int8_fused_moe(
        self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size
    ):
        torch.manual_seed(seed)
        a = torch.randn((M, K), dtype=dtype) / 10
        # Generate int8 weights
        w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2
        w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2
        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            marlin_out, ref_out = marlin_fused_moe(
                N=N,
                E=E,
                K=K,
                a=a,
                w1=w1_fp16,
                w2=w2_fp16,
                num_bits=num_bits,
                group_size=-1,
                act_order=False,
                score=score,
                topk=topk,
                ep_size=ep_size,
            )

            # Check results
            if (
                torch.mean(
                    torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
                )
                / torch.mean(torch.abs(ref_out.to(torch.float32)))
                > 0.1
            ):
                print(f"marlin_out: {marlin_out}")
                print(f"ref_out: {ref_out}")
                print(
                    torch.mean(
                        torch.abs(
                            marlin_out.to(torch.float32) - ref_out.to(torch.float32)
                        )
                    )
                    / torch.mean(torch.abs(ref_out.to(torch.float32)))
                )
            torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0)

    def test_w4a8_int8_fused_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
            self.NUM_BITS,
            self.EP_SIZE,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
                num_bits=params[8],
                ep_size=params[9],
            ):
                self._w4a8_int8_fused_moe(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
test/srt/test_w4a8.py
deleted 100644 → 0
View file @ 570d3343
import sgl_kernel
import torch

x = torch.randn(10, 10, device="cuda")
qweight = torch.randn(10, 10, device="cuda")
s1_scales = torch.randn(10, device="cuda")
input_scales = torch.randn(10, device="cuda")
s1_szeros = torch.randn(10, device="cuda")
input_sum = torch.randn(10, device="cuda")
output_buffer = torch.randn(10, device="cuda")
torch.ops.sgl_kernel.gemm_forward_cuda.default(
    x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer
)