Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
09972e71
Unverified
Commit
09972e71
authored
Feb 12, 2025
by
Michael Goin
Committed by
GitHub
Feb 12, 2025
Browse files
[Bugfix] Allow fallback to AWQ from AWQMarlin at per-layer granularity (#13119)
parent
36a08630
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
58 additions
and
29 deletions
+58
-29
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+19
-16
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+18
-10
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+6
-3
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+15
-0
No files found.
vllm/model_executor/layers/linear.py
View file @
09972e71
...
@@ -290,29 +290,30 @@ class ColumnParallelLinear(LinearBase):
...
@@ -290,29 +290,30 @@ class ColumnParallelLinear(LinearBase):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
)
self
.
gather_output
=
gather_output
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
quant_method
is
not
Non
e
self
.
input_size_per_partition
=
input_siz
e
self
.
output_size_per_partition
=
divide
(
self
.
output_size
,
tp_size
)
self
.
output_size_per_partition
=
divide
(
output_size
,
self
.
tp_size
)
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
# If QKV or MergedColumn, use output size of each partition.
# If QKV or MergedColumn, use output size of each partition.
if
hasattr
(
self
,
"output_sizes"
):
if
hasattr
(
self
,
"output_sizes"
):
self
.
output_partition_sizes
=
[
self
.
output_partition_sizes
=
[
divide
(
output_size
,
tp_size
)
divide
(
output_size
,
self
.
tp_size
)
for
output_size
in
self
.
output_sizes
for
output_size
in
self
.
output_sizes
]
]
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
)
self
.
gather_output
=
gather_output
if
output_sizes
is
None
:
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
output_sizes
=
[
output_size
]
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
.
quant_method
.
create_weights
(
layer
=
self
,
layer
=
self
,
input_size_per_partition
=
self
.
input_size
,
input_size_per_partition
=
self
.
input_size
_per_partition
,
output_partition_sizes
=
self
.
output_partition_sizes
,
output_partition_sizes
=
self
.
output_partition_sizes
,
input_size
=
self
.
input_size
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
output_size
=
self
.
output_size
,
...
@@ -1044,22 +1045,24 @@ class RowParallelLinear(LinearBase):
...
@@ -1044,22 +1045,24 @@ class RowParallelLinear(LinearBase):
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
):
# Divide the weight matrix along the first dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
output_size
self
.
output_partition_sizes
=
[
output_size
]
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
)
quant_config
,
prefix
)
self
.
input_is_parallel
=
input_is_parallel
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
# Divide the weight matrix along the last dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
.
quant_method
.
create_weights
(
layer
=
self
,
layer
=
self
,
input_size_per_partition
=
self
.
input_size_per_partition
,
input_size_per_partition
=
self
.
input_size_per_partition
,
output_partition_sizes
=
[
self
.
output_size
]
,
output_partition_sizes
=
self
.
output_
partition_
size
s
,
input_size
=
self
.
input_size
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
09972e71
...
@@ -13,15 +13,17 @@ from vllm.model_executor.layers.fused_moe.layer import (
...
@@ -13,15 +13,17 @@ from vllm.model_executor.layers.fused_moe.layer import (
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
UnquantizedLinearMethod
,
set_weight_attrs
)
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.awq
import
is_layer_skipped_awq
from
vllm.model_executor.layers.quantization.awq
import
(
AWQConfig
,
is_layer_skipped_awq
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_moe_permute_scales
,
check_marlin_supports_layer
,
marlin_make_empty_g_idx
,
marlin_permute_scales
,
moe_awq_to_marlin_zero_points
,
marlin_make_workspace
,
marlin_moe_permute_scales
,
marlin_permute_scales
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
moe_awq_to_marlin_zero_points
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
)
...
@@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig):
8
:
scalar_types
.
uint8
,
8
:
scalar_types
.
uint8
,
}
}
def
__init__
(
self
,
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
zero_point
:
bool
,
weight_bits
:
int
,
group_size
:
int
,
zero_point
:
bool
,
lm_head_quantized
:
bool
,
lm_head_quantized
:
bool
,
modules_to_not_convert
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
modules_to_not_convert
:
Optional
[
List
[
str
]],
full_config
:
Dict
[
str
,
Any
])
->
None
:
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
group_size
=
group_size
self
.
zero_point
=
zero_point
self
.
zero_point
=
zero_point
self
.
lm_head_quantized
=
lm_head_quantized
self
.
lm_head_quantized
=
lm_head_quantized
self
.
weight_bits
=
weight_bits
self
.
weight_bits
=
weight_bits
self
.
modules_to_not_convert
=
modules_to_not_convert
or
[]
self
.
modules_to_not_convert
=
modules_to_not_convert
or
[]
self
.
full_config
=
full_config
if
self
.
weight_bits
not
in
self
.
TYPE_MAP
:
if
self
.
weight_bits
not
in
self
.
TYPE_MAP
:
raise
ValueError
(
f
"Unsupported num_bits =
{
self
.
weight_bits
}
. "
raise
ValueError
(
f
"Unsupported num_bits =
{
self
.
weight_bits
}
. "
...
@@ -96,7 +97,7 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -96,7 +97,7 @@ class AWQMarlinConfig(QuantizationConfig):
modules_to_not_convert
=
cls
.
get_from_keys_or
(
modules_to_not_convert
=
cls
.
get_from_keys_or
(
config
,
[
"modules_to_not_convert"
],
None
)
config
,
[
"modules_to_not_convert"
],
None
)
return
cls
(
weight_bits
,
group_size
,
zero_point
,
lm_head_quantized
,
return
cls
(
weight_bits
,
group_size
,
zero_point
,
lm_head_quantized
,
modules_to_not_convert
)
modules_to_not_convert
,
config
)
@
classmethod
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
...
@@ -124,6 +125,13 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -124,6 +125,13 @@ class AWQMarlinConfig(QuantizationConfig):
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
if
is_layer_skipped_awq
(
prefix
,
self
.
modules_to_not_convert
):
if
is_layer_skipped_awq
(
prefix
,
self
.
modules_to_not_convert
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
# Check if the layer is supported by AWQMarlin.
if
not
check_marlin_supports_layer
(
layer
,
self
.
group_size
):
logger
.
warning_once
(
f
"Layer '
{
prefix
}
' is not supported by AWQMarlin. "
"Falling back to unoptimized AWQ kernels."
)
return
AWQConfig
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
return
AWQMarlinLinearMethod
(
self
)
return
AWQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
return
AWQMoEMethod
(
self
)
return
AWQMoEMethod
(
self
)
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
09972e71
...
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
GPTQMarlinConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_marlin_supports_layer
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -87,8 +89,8 @@ class MoeWNA16Config(QuantizationConfig):
...
@@ -87,8 +89,8 @@ class MoeWNA16Config(QuantizationConfig):
modules_to_not_convert
=
[]
modules_to_not_convert
=
[]
elif
linear_quant_method
==
"awq"
:
elif
linear_quant_method
==
"awq"
:
has_zp
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
has_zp
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
modules_to_not_convert
=
cls
.
get_from_keys
(
modules_to_not_convert
=
cls
.
get_from_keys
_or
(
config
,
[
"modules_to_not_convert"
])
config
,
[
"modules_to_not_convert"
]
,
None
)
else
:
else
:
raise
ValueError
(
"moe_wna16 only support gptq and awq."
)
raise
ValueError
(
"moe_wna16 only support gptq and awq."
)
...
@@ -135,7 +137,8 @@ class MoeWNA16Config(QuantizationConfig):
...
@@ -135,7 +137,8 @@ class MoeWNA16Config(QuantizationConfig):
return
GPTQConfig
.
from_config
(
return
GPTQConfig
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
elif
self
.
linear_quant_method
==
"awq"
:
elif
self
.
linear_quant_method
==
"awq"
:
if
self
.
use_marlin
:
if
self
.
use_marlin
and
check_marlin_supports_layer
(
layer
,
self
.
group_size
):
return
AWQMarlinConfig
.
from_config
(
return
AWQMarlinConfig
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
else
:
else
:
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
09972e71
...
@@ -6,6 +6,7 @@ import numpy
...
@@ -6,6 +6,7 @@ import numpy
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
...
@@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int,
...
@@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int,
return
True
,
None
return
True
,
None
def
check_marlin_supports_layer
(
layer
:
LinearBase
,
group_size
:
int
)
\
->
bool
:
output_size_per_partition
=
getattr
(
layer
,
"output_size_per_partition"
,
None
)
or
layer
.
output_size
input_size_per_partition
=
getattr
(
layer
,
"input_size_per_partition"
,
None
)
or
layer
.
input_size
return
check_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
layer
.
input_size
,
group_size
=
group_size
)[
0
]
def
marlin_make_workspace
(
output_size_per_partition
:
int
,
def
marlin_make_workspace
(
output_size_per_partition
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
device
:
torch
.
device
)
->
torch
.
Tensor
:
max_workspace_size
=
(
output_size_per_partition
//
max_workspace_size
=
(
output_size_per_partition
//
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment