Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
73e2e011
Unverified
Commit
73e2e011
authored
Jun 12, 2025
by
Jee Jee Li
Committed by
GitHub
Jun 12, 2025
Browse files
[Quantization] Improve AWQ logic (#19431)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
c9280e63
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
4 deletions
+38
-4
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+38
-4
No files found.
vllm/model_executor/layers/quantization/awq.py
View file @
73e2e011
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
,
Union
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
)
logger
=
init_logger
(
__name__
)
class
AWQConfig
(
QuantizationConfig
):
class
AWQConfig
(
QuantizationConfig
):
"""Config class for AWQ.
"""Config class for AWQ.
...
@@ -74,12 +78,42 @@ class AWQConfig(QuantizationConfig):
...
@@ -74,12 +78,42 @@ class AWQConfig(QuantizationConfig):
config
,
[
"modules_to_not_convert"
],
None
)
config
,
[
"modules_to_not_convert"
],
None
)
return
cls
(
weight_bits
,
group_size
,
zero_point
,
modules_to_not_convert
)
return
cls
(
weight_bits
,
group_size
,
zero_point
,
modules_to_not_convert
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
prefix
:
str
)
->
Optional
[
"LinearMethodBase"
]:
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
Union
[
"LinearMethodBase"
,
"QuantizeMethodBase"
]]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped_awq
(
prefix
,
self
.
modules_to_not_convert
):
if
is_layer_skipped_awq
(
prefix
,
self
.
modules_to_not_convert
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
return
AWQLinearMethod
(
self
)
return
AWQLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
# Lazy import to avoid circular import.
from
.awq_marlin
import
AWQMarlinConfig
,
AWQMoEMethod
from
.moe_wna16
import
MoeWNA16Config
from
.utils.marlin_utils
import
check_moe_marlin_supports_layer
if
not
check_moe_marlin_supports_layer
(
layer
,
self
.
group_size
):
logger
.
warning_once
(
f
"Layer '
{
prefix
}
' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels."
)
config
=
{
"quant_method"
:
"awq"
,
"bits"
:
self
.
weight_bits
,
"group_size"
:
self
.
group_size
,
"zero_point"
:
self
.
zero_point
,
"lm_head"
:
False
,
}
return
MoeWNA16Config
.
from_config
(
config
).
get_quant_method
(
layer
,
prefix
)
marlin_compatible_config_dict
=
{
"quant_method"
:
"awq"
,
"bits"
:
self
.
weight_bits
,
"group_size"
:
self
.
group_size
,
"zero_point"
:
self
.
zero_point
,
"lm_head"
:
False
,
"modules_to_not_convert"
:
self
.
modules_to_not_convert
,
}
awq_marlin_config
=
AWQMarlinConfig
.
from_config
(
marlin_compatible_config_dict
)
return
AWQMoEMethod
(
awq_marlin_config
)
return
None
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment