change / sglang / Commits / ae25d36d

Commit ae25d36d (unverified)
[3/3] fix dsv3 awq issue (#4719)
Authored by laixin on Mar 27, 2025; committed by GitHub on Mar 26, 2025
Co-authored-by: AniZpZ <aniz1905@gmail.com>
Parent: 1099f6c9
Showing 2 changed files with 204 additions and 4 deletions:

    python/sglang/srt/layers/quantization/__init__.py   (+4, -4)
    python/sglang/srt/layers/quantization/awq.py         (+200, -0)
python/sglang/srt/layers/quantization/__init__.py

@@ -9,7 +9,6 @@ import torch
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq import AWQConfig
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
...
@@ -33,14 +32,15 @@ except ImportError:
     class DummyConfig:
         pass

-    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
-        CompressedTensorsConfig
-    ) = DummyConfig
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DummyConfig
+    )
     DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
         GPTQMarlin24Config
     ) = DummyConfig
     MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig

+from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
...
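Note on the hunk above: AWQConfig is dropped from the group of names that fall back to DummyConfig when vllm is absent, because the sglang-native implementation (added in awq.py below) is imported unconditionally instead. A minimal, self-contained sketch of the same optional-import pattern, using a hypothetical package name in place of vllm, looks like this:

# Sketch of the optional-dependency fallback used in __init__.py.
# "optional_backend" is a hypothetical stand-in for vllm; only the pattern matters.
try:
    from optional_backend.quantization import MarlinLikeConfig  # hypothetical import
except ImportError:

    class DummyConfig:
        """Placeholder type so the name still exists without the optional backend."""

        pass

    MarlinLikeConfig = DummyConfig  # uses of the name then degrade gracefully

# The sglang-native AWQ config needs no such fallback; it is imported directly,
# which is what the added "from sglang.srt.layers.quantization.awq import AWQConfig"
# line accomplishes.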
python/sglang/srt/layers/quantization/awq.py  (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Any, Dict, List, Optional, Union

import torch
from sgl_kernel import awq_dequantize

from sglang.srt.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import QuantizationConfig

logger = logging.getLogger(__name__)


def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
    return any(module_name in prefix for module_name in modules_to_not_convert)


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[List[str]] = None,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits."
            )
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (
            f"AWQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["LinearMethodBase"]:
        if isinstance(layer, LinearBase):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)
        return None


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
        reshaped_x = x.reshape(-1, x.shape[-1])

        out = awq_dequantize(qweight, scales, qzeros)
        out = torch.matmul(reshaped_x, out)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
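As a usage note (not part of the commit), the sketch below exercises the config class added above. The dict mirrors the w_bit / q_group_size / zero_point keys that AWQConfig.from_config reads from a checkpoint's quantize_config.json; the concrete values are hypothetical. With 4-bit weights, pack_factor works out to 32 // 4 = 8, i.e. eight quantized weights per packed int32, which is the factor apply() uses to recover the output width before the dequantize-then-matmul.

# Hypothetical AWQ checkpoint config; keys match those read by AWQConfig.from_config.
from sglang.srt.layers.quantization.awq import AWQConfig

quant_cfg = {
    "w_bit": 4,               # only 4-bit weights are accepted by this config
    "q_group_size": 128,      # one scale/zero-point group per 128 input channels
    "zero_point": True,
    "modules_to_not_convert": ["lm_head"],  # layers left unquantized
}

cfg = AWQConfig.from_config(quant_cfg)
assert cfg.pack_factor == 32 // 4  # eight 4-bit values per int32 word
print(cfg)
# AWQConfig(weight_bits=4, group_size=128, zero_point=True, modules_to_not_convert=['lm_head'])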