sglang · Commits · 1b2e8f76

Commit 1b2e8f76 (unverified), authored May 24, 2025 by HandH1998, committed by GitHub May 23, 2025

[2/2] Support Qserve (#6521)

Parent: d2e0881a

Showing 5 changed files with 268 additions and 5 deletions (+268 −5).
python/sglang/srt/configs/model_config.py             +3   −0
python/sglang/srt/layers/quantization/__init__.py     +2   −0
python/sglang/srt/layers/quantization/int8_kernel.py  +18  −5
python/sglang/srt/layers/quantization/qoq.py          +244 −0
python/sglang/srt/server_args.py                      +1   −0
python/sglang/srt/configs/model_config.py

@@ -349,6 +349,7 @@ class ModelConfig:
             "w8a8_int8",
             "w8a8_fp8",
             "moe_wna16",
+            "qoq",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],

@@ -458,6 +459,8 @@ def _get_and_verify_dtype(
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
     config_dtype = getattr(config, "torch_dtype", None)
+    if isinstance(config_dtype, str):
+        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:
         config_dtype = torch.float32
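The second hunk lets _get_and_verify_dtype accept a torch_dtype that arrives as a string (some configs serialize it as "float16" rather than a torch.dtype). A minimal sketch of the resolution behavior; the helper name and the mapping below are hypothetical stand-ins for the module's _STR_DTYPE_TO_TORCH_DTYPE table:

import torch

# Hypothetical stand-in for the mapping referenced in the hunk above.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def resolve_config_dtype(config_dtype):
    # Mirrors the patched logic: map strings to torch dtypes, then fall
    # back to float32 if nothing resolved.
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32
    return config_dtype

assert resolve_config_dtype("float16") is torch.float16
assert resolve_config_dtype(None) is torch.float32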
python/sglang/srt/layers/quantization/__init__.py

@@ -67,6 +67,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

@@ -80,6 +81,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
+    "qoq": QoQConfig,
 }
 # VLLM-dependent quantization methods
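With the registry entry in place, "qoq" resolves by name like the other built-in methods. A minimal lookup sketch, using only the BASE_QUANTIZATION_METHODS dict shown above and the classmethods defined in qoq.py below:

from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

# "qoq" now maps to the QoQConfig class added in this commit.
config_cls = BASE_QUANTIZATION_METHODS["qoq"]
print(config_cls.get_name())            # "qoq"
print(config_cls.get_min_capability())  # 80, i.e. SM80 (Ampere) or newer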
python/sglang/srt/layers/quantization/int8_kernel.py

@@ -22,9 +22,11 @@ def _per_token_quant_int8(
     x_ptr,
     xq_ptr,
     scale_ptr,
+    x_sum_ptr,
     stride_x,
     stride_xq,
     N,
+    CAL_SUM: tl.constexpr,
     BLOCK: tl.constexpr,
 ):
     # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282

@@ -38,16 +40,23 @@ def _per_token_quant_int8(
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
     x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    if CAL_SUM:
+        x_sum = tl.sum(x, axis=0)
+        tl.store(x_sum_ptr + row_id, x_sum.to(x_sum_ptr.dtype.element_ty))
     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
-    tl.store(scale_ptr + row_id, scale_x)
+    tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty))


-def per_token_quant_int8(x):
+def per_token_quant_int8(x, scale_dtype=torch.float32, cal_sum=False):
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
-    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype)
+    if cal_sum:
+        x_sum = torch.empty(x.shape[:-1], device=x.device, dtype=x.dtype)
+    else:
+        x_sum = None
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)

@@ -57,15 +66,19 @@ def per_token_quant_int8(x):
         x,
         x_q,
         scales,
+        x_sum,
         stride_x=x.stride(-2),
         stride_xq=x_q.stride(-2),
         N=N,
+        CAL_SUM=cal_sum,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
-    return x_q, scales
+    if cal_sum:
+        return x_q, scales, x_sum
+    else:
+        return x_q, scales


 @triton.jit
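per_token_quant_int8 now optionally returns the per-token sum of the unquantized activations; the per-channel QServe GEMM below uses this to fold the asymmetric weight zero-point term back into the int8 product. A minimal usage sketch of the new signature, assuming a CUDA device and an fp16 activation tensor:

import torch
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")

# Default path: int8 values plus one float32 scale per token.
x_q, scales = per_token_quant_int8(x)
assert x_q.dtype == torch.int8 and scales.shape == (4, 1)

# QoQ per-channel path: keep scales in the activation dtype and also
# return the per-token sum of x for the zero-point correction.
x_q, scales, x_sum = per_token_quant_int8(x, scale_dtype=x.dtype, cal_sum=True)
assert scales.dtype == torch.float16 and x_sum.shape == (4,)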
python/sglang/srt/layers/quantization/qoq.py (new file, 0 → 100644)

from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    ModelWeightParameter,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm

QoQ_SUPPORTED_WEIGHT_BITS = [4]
QoQ_SUPPORTED_GROUP_SIZES = [-1, 128]


class QoQConfig(QuantizationConfig):
    """Config class for QoQ Quantization.

    - Weight: static, per-channel/group, asymmetric
    - Activation: dynamic, per-token, symmetric

    Reference: https://arxiv.org/abs/2405.04532
               https://github.com/mit-han-lab/omniserve
    """

    def __init__(self, weight_bits: int, group_size: int) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        # Verify
        if self.weight_bits not in QoQ_SUPPORTED_WEIGHT_BITS:
            raise ValueError(
                f"QoQ does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {QoQ_SUPPORTED_WEIGHT_BITS} "
                "are supported."
            )
        if self.group_size not in QoQ_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"QoQ does not support group_size = {self.group_size}. "
                f"Only group_sizes = {QoQ_SUPPORTED_GROUP_SIZES} "
                "are supported."
            )

        # 4 bits packed into 8 bit datatype.
        self.pack_factor = 8 // self.weight_bits

    def __repr__(self) -> str:
        return "QoQConfig(weight_bits={}, group_size={})".format(
            self.weight_bits, self.group_size
        )

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_name(self) -> str:
        return "qoq"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QoQConfig":
        weight_bits = cls.get_from_keys(config, ["wbits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            return QoQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class QoQLinearMethod(LinearMethodBase):
    """Linear method for QoQ.

    Args:
        quant_config: The QoQ quantization config.
    """

    def __init__(self, quant_config: QoQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % 32 != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by 32."
            )

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}."
            )
        if (
            self.quant_config.group_size != -1
            and input_size_per_partition % self.quant_config.group_size != 0
        ):
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {self.quant_config.group_size}."
            )

        qweight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("qweight", qweight)

        s1_scales = ChannelQuantScaleParameter(
            data=torch.empty(output_size_per_partition, dtype=torch.float16),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("s1_scales", s1_scales)

        if self.quant_config.group_size == -1:
            s1_szeros = ChannelQuantScaleParameter(
                data=torch.empty(output_size_per_partition, dtype=torch.float16),
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("s1_szeros", s1_szeros)
        else:
            s2_scales = GroupQuantScaleParameter(
                data=torch.empty(
                    (
                        input_size_per_partition // self.quant_config.group_size,
                        output_size_per_partition,
                    ),
                    dtype=torch.int8,
                ),
                input_dim=0,
                output_dim=1,
                weight_loader=weight_loader,
            )
            layer.register_parameter("s2_scales", s2_scales)

            s2_zeros = GroupQuantScaleParameter(
                data=torch.empty(
                    (
                        input_size_per_partition // self.quant_config.group_size,
                        output_size_per_partition,
                    ),
                    dtype=torch.int8,
                ),
                input_dim=0,
                output_dim=1,
                weight_loader=weight_loader,
            )
            layer.register_parameter("s2_zeros", s2_zeros)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.s1_scales = Parameter(layer.s1_scales.data, requires_grad=False)
        if self.quant_config.group_size == -1:
            layer.s1_szeros = Parameter(layer.s1_szeros.data, requires_grad=False)
        else:
            layer.s2_scales = Parameter(layer.s2_scales.data, requires_grad=False)
            layer.s2_zeros = Parameter(layer.s2_zeros.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        assert x.dtype == torch.float16, "QoQ only supports float16 input now"
        if self.quant_config.group_size == -1:
            x_q, x_scale, x_sum = per_token_quant_int8(
                x, scale_dtype=x.dtype, cal_sum=True
            )
            out = qserve_w4a8_per_chn_gemm(
                x_q, layer.qweight, layer.s1_scales, x_scale, layer.s1_szeros, x_sum
            )
        else:
            x_q, x_scale = per_token_quant_int8(x, scale_dtype=x.dtype)
            out = qserve_w4a8_per_group_gemm(
                x_q,
                layer.qweight,
                layer.s2_zeros,
                layer.s2_scales,
                layer.s1_scales,
                x_scale,
            )
        if bias is not None:
            out = out + bias
        return out
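For reference, the quant_config.json / quantize_config.json that from_config reads only needs the two keys used above. A minimal sketch, assuming a hypothetical checkpoint exported with 4-bit weights and group size 128:

from sglang.srt.layers.quantization.qoq import QoQConfig

# Hypothetical contents of the checkpoint's quant_config.json.
quant_cfg = {"wbits": 4, "group_size": 128}

cfg = QoQConfig.from_config(quant_cfg)
print(cfg)              # QoQConfig(weight_bits=4, group_size=128)
print(cfg.pack_factor)  # 2: two 4-bit weights packed into each int8 element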
python/sglang/srt/server_args.py

@@ -577,6 +577,7 @@ class ServerArgs:
             "w8a8_int8",
             "w8a8_fp8",
             "moe_wna16",
+            "qoq",
         ],
         help="The quantization method.",
     )
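With the new choice registered, the method can be selected at launch time. A minimal command-line sketch; the model path is a placeholder for a QServe/QoQ-quantized checkpoint:

python -m sglang.launch_server \
    --model-path <path-to-qoq-quantized-model> \
    --quantization qoq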