Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8f0a9ca8
Unverified
Commit
8f0a9ca8
authored
Nov 04, 2024
by
Michael Goin
Committed by
GitHub
Nov 04, 2024
Browse files
[Bugfix] Respect modules_to_not_convert within awq_marlin (#9895)
Signed-off-by:
mgoin
<
michael@neuralmagic.com
>
parent
2094062b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
11 deletions
+24
-11
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+24
-11
No files found.
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
8f0a9ca8
...
...
@@ -9,7 +9,9 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.awq
import
is_layer_skipped_awq
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
...
...
@@ -36,13 +38,18 @@ class AWQMarlinConfig(QuantizationConfig):
8
:
scalar_types
.
uint8
,
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
zero_point
:
bool
,
lm_head_quantized
:
bool
,
modules_to_not_convert
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
has_zp
=
has_zp
self
.
zero_point
=
zero_point
self
.
lm_head_quantized
=
lm_head_quantized
self
.
weight_bits
=
weight_bits
self
.
modules_to_not_convert
=
modules_to_not_convert
or
[]
if
self
.
weight_bits
not
in
self
.
TYPE_MAP
:
raise
ValueError
(
f
"Unsupported num_bits =
{
self
.
weight_bits
}
. "
...
...
@@ -52,13 +59,14 @@ class AWQMarlinConfig(QuantizationConfig):
verify_marlin_supported
(
self
.
quant_type
,
group_size
=
self
.
group_size
,
has_zp
=
self
.
has_zp
)
has_zp
=
self
.
zero_point
)
def
__repr__
(
self
)
->
str
:
return
(
f
"AWQMarlinConfig(quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"has_zp=
{
self
.
has_zp
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
f
"zero_point=
{
self
.
zero_point
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
, "
f
"modules_to_not_convert=
{
self
.
modules_to_not_convert
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -80,10 +88,13 @@ class AWQMarlinConfig(QuantizationConfig):
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"AWQMarlinConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
has_zp
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
has_zp
,
lm_head_quantized
)
modules_to_not_convert
=
cls
.
get_from_keys_or
(
config
,
[
"modules_to_not_convert"
],
None
)
return
cls
(
weight_bits
,
group_size
,
zero_point
,
lm_head_quantized
,
modules_to_not_convert
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
...
...
@@ -109,6 +120,8 @@ class AWQMarlinConfig(QuantizationConfig):
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
if
is_layer_skipped_awq
(
prefix
,
self
.
modules_to_not_convert
):
return
UnquantizedLinearMethod
()
return
AWQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
AWQMoEMethod
(
self
)
...
...
@@ -123,7 +136,7 @@ class AWQMarlinConfig(QuantizationConfig):
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
)
group_size
=
quant_config
.
get
(
"group_size"
)
has_zp
=
quant_config
.
get
(
"zero_point"
)
zero_point
=
quant_config
.
get
(
"zero_point"
)
if
not
current_platform
.
is_cuda
():
return
False
...
...
@@ -132,7 +145,7 @@ class AWQMarlinConfig(QuantizationConfig):
return
False
# If we cannot find the info needed in the config, cannot convert.
if
(
num_bits
is
None
or
group_size
is
None
or
has_zp
is
None
):
if
(
num_bits
is
None
or
group_size
is
None
or
zero_point
is
None
):
return
False
if
num_bits
not
in
cls
.
TYPE_MAP
:
...
...
@@ -140,7 +153,7 @@ class AWQMarlinConfig(QuantizationConfig):
return
check_marlin_supported
(
quant_type
=
cls
.
TYPE_MAP
[
num_bits
],
group_size
=
group_size
,
has_zp
=
has_zp
)
has_zp
=
zero_point
)
class
AWQMarlinLinearMethod
(
LinearMethodBase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment