Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ae25d36d
Unverified
Commit
ae25d36d
authored
Mar 27, 2025
by
laixin
Committed by
GitHub
Mar 26, 2025
Browse files
[3/3] fix dsv3 awq issue (#4719)
Co-authored-by:
AniZpZ
<
aniz1905@gmail.com
>
parent
1099f6c9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
204 additions
and
4 deletions
+204
-4
python/sglang/srt/layers/quantization/__init__.py
python/sglang/srt/layers/quantization/__init__.py
+4
-4
python/sglang/srt/layers/quantization/awq.py
python/sglang/srt/layers/quantization/awq.py
+200
-0
No files found.
python/sglang/srt/layers/quantization/__init__.py
View file @
ae25d36d
...
@@ -9,7 +9,6 @@ import torch
...
@@ -9,7 +9,6 @@ import torch
try
:
try
:
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
AWQMarlinConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
AWQMarlinConfig
from
vllm.model_executor.layers.quantization.bitsandbytes
import
BitsAndBytesConfig
from
vllm.model_executor.layers.quantization.bitsandbytes
import
BitsAndBytesConfig
from
vllm.model_executor.layers.quantization.deepspeedfp
import
DeepSpeedFPConfig
from
vllm.model_executor.layers.quantization.deepspeedfp
import
DeepSpeedFPConfig
...
@@ -33,14 +32,15 @@ except ImportError:
...
@@ -33,14 +32,15 @@ except ImportError:
class
DummyConfig
:
class
DummyConfig
:
pass
pass
AQLMConfig
=
AWQConfig
=
AWQMarlinConfig
=
BitsAndBytesConfig
=
(
AQLMConfig
=
AWQMarlinConfig
=
BitsAndBytesConfig
=
CompressedTensorsConfig
=
(
CompressedTensors
Config
Dummy
Config
)
=
DummyConfig
)
DeepSpeedFPConfig
=
ExpertsInt8Config
=
FBGEMMFp8Config
=
GGUFConfig
=
(
DeepSpeedFPConfig
=
ExpertsInt8Config
=
FBGEMMFp8Config
=
GGUFConfig
=
(
GPTQMarlin24Config
GPTQMarlin24Config
)
=
DummyConfig
)
=
DummyConfig
MarlinConfig
=
QQQConfig
=
Int8TpuConfig
=
DummyConfig
MarlinConfig
=
QQQConfig
=
Int8TpuConfig
=
DummyConfig
from
sglang.srt.layers.quantization.awq
import
AWQConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.blockwise_int8
import
BlockInt8Config
from
sglang.srt.layers.quantization.blockwise_int8
import
BlockInt8Config
from
sglang.srt.layers.quantization.compressed_tensors.compressed_tensors
import
(
from
sglang.srt.layers.quantization.compressed_tensors.compressed_tensors
import
(
...
...
python/sglang/srt/layers/quantization/awq.py
0 → 100644
View file @
ae25d36d
# SPDX-License-Identifier: Apache-2.0
import
logging
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
torch
from
sgl_kernel
import
awq_dequantize
from
sglang.srt.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
)
from
sglang.srt.layers.parameter
import
GroupQuantScaleParameter
,
PackedvLLMParameter
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
logger
=
logging
.
getLogger
(
__name__
)
def
is_layer_skipped_awq
(
prefix
:
str
,
modules_to_not_convert
:
List
[
str
]):
return
any
(
module_name
in
prefix
for
module_name
in
modules_to_not_convert
)
class
AWQConfig
(
QuantizationConfig
):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
zero_point
:
bool
,
modules_to_not_convert
:
Optional
[
List
[
str
]]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
zero_point
=
zero_point
self
.
modules_to_not_convert
=
modules_to_not_convert
or
[]
if
self
.
weight_bits
!=
4
:
raise
ValueError
(
"Currently, only 4-bit weight quantization is supported for "
f
"AWQ, but got
{
self
.
weight_bits
}
bits."
)
self
.
pack_factor
=
32
//
self
.
weight_bits
def
__repr__
(
self
)
->
str
:
return
(
f
"AWQConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"zero_point=
{
self
.
zero_point
}
, "
f
"modules_to_not_convert=
{
self
.
modules_to_not_convert
}
)"
)
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
def
get_name
(
self
)
->
str
:
return
"awq"
def
get_supported_act_dtypes
(
self
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# The AWQ kernel only supports Turing or newer GPUs.
return
75
@
staticmethod
def
get_config_filenames
()
->
List
[
str
]:
return
[
"quant_config.json"
,
# E.g., casperhansen/vicuna-7b-v1.5-awq
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
"quantize_config.json"
,
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"AWQConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"w_bit"
,
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"q_group_size"
,
"group_size"
])
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
modules_to_not_convert
=
cls
.
get_from_keys_or
(
config
,
[
"modules_to_not_convert"
],
None
)
return
cls
(
weight_bits
,
group_size
,
zero_point
,
modules_to_not_convert
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"LinearMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped_awq
(
prefix
,
self
.
modules_to_not_convert
):
return
UnquantizedLinearMethod
()
return
AWQLinearMethod
(
self
)
return
None
class
AWQLinearMethod
(
LinearMethodBase
):
"""Linear method for AWQ.
Args:
quant_config: The AWQ quantization config.
"""
def
__init__
(
self
,
quant_config
:
AWQConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
qweight
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
,
)
qzeros
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
,
)
scales
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
0
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
register_parameter
(
"scales"
,
scales
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
qweight
=
torch
.
nn
.
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
torch
.
nn
.
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
scales
=
torch
.
nn
.
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
scales
=
layer
.
scales
qzeros
=
layer
.
qzeros
pack_factor
=
self
.
quant_config
.
pack_factor
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
]
*
pack_factor
,)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out
=
awq_dequantize
(
qweight
,
scales
,
qzeros
)
out
=
torch
.
matmul
(
reshaped_x
,
out
)
if
bias
is
not
None
:
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment