Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ec82c3e3
Unverified
Commit
ec82c3e3
authored
May 24, 2025
by
Wenhua Cheng
Committed by
GitHub
May 23, 2025
Browse files
FIX MOE issue in AutoRound format (#18586)
Signed-off-by:
wenhuach21
<
wenhua.cheng@intel.com
>
parent
45ab403a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
28 deletions
+30
-28
README.md
README.md
+1
-1
vllm/model_executor/layers/quantization/auto_round.py
vllm/model_executor/layers/quantization/auto_round.py
+29
-27
No files found.
README.md
View file @
ec82c3e3
...
...
@@ -58,7 +58,7 @@ vLLM is fast with:
-
Efficient management of attention key and value memory with
[
**PagedAttention**
](
https://blog.vllm.ai/2023/06/20/vllm.html
)
-
Continuous batching of incoming requests
-
Fast model execution with CUDA/HIP graph
-
Quantizations:
[
GPTQ
](
https://arxiv.org/abs/2210.17323
)
,
[
AWQ
](
https://arxiv.org/abs/2306.00978
)
, INT4, INT8, and FP8.
-
Quantizations:
[
GPTQ
](
https://arxiv.org/abs/2210.17323
)
,
[
AWQ
](
https://arxiv.org/abs/2306.00978
)
,
[
AutoRound
](
https://arxiv.org/abs/2309.05516
)
,
INT4, INT8, and FP8.
-
Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
-
Speculative decoding
-
Chunked prefill
...
...
vllm/model_executor/layers/quantization/auto_round.py
View file @
ec82c3e3
...
...
@@ -8,6 +8,7 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
...
...
@@ -74,7 +75,7 @@ class AutoRoundConfig(QuantizationConfig):
f
"group_size=
{
self
.
group_size
}
, sym=
{
self
.
sym
}
)"
)
@
classmethod
def
get_name
(
cls
)
:
## use str will trigger preci issue
def
get_name
(
cls
)
->
QuantizationMethods
:
return
"auto-round"
@
classmethod
...
...
@@ -142,18 +143,18 @@ class AutoRoundConfig(QuantizationConfig):
prefix
,
layer
.
__class__
.
__name__
,
weight_bits
,
group_size
,
sym
)
if
backend
==
"auto"
or
"marlin"
in
backend
:
AWQ_TYPE_MAP
=
{
4
:
scalar_types
.
uint4
,
8
:
scalar_types
.
uint8
,
}
use_marlin
=
(
weight_bits
in
AWQ_TYPE_MAP
)
and
check_marlin_supported
(
AWQ_TYPE_MAP
[
weight_bits
],
group_size
,
not
sym
)
if
isinstance
(
layer
,
FusedMoE
):
use_marlin
=
check_moe_marlin_supports_layer
(
layer
,
group_size
)
else
:
use_marlin
=
use_marlin
and
check_moe_marlin_supports_layer
(
layer
,
group_size
)
AWQ_TYPE_MAP
=
{
4
:
scalar_types
.
uint4
,
8
:
scalar_types
.
uint8
,
}
use_marlin
=
((
weight_bits
,
sym
)
in
AWQ_TYPE_MAP
and
check_marlin_supported
(
AWQ_TYPE_MAP
[(
weight_bits
)],
group_size
,
not
sym
))
else
:
use_marlin
=
False
if
use_marlin
:
...
...
@@ -180,10 +181,11 @@ class AutoRoundConfig(QuantizationConfig):
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
MoeWNA16Config
)
config
=
{
"
linear_
quant_method"
:
"awq"
,
"
weight_
bits"
:
weight_bits
,
"quant_method"
:
"awq"
,
"bits"
:
weight_bits
,
"group_size"
:
group_size
,
"zero_point"
:
not
sym
,
"lm_head"
:
False
,
}
return
MoeWNA16Config
.
from_config
(
config
).
get_quant_method
(
layer
,
prefix
)
...
...
@@ -213,18 +215,18 @@ class AutoRoundConfig(QuantizationConfig):
prefix
,
layer
.
__class__
.
__name__
,
weight_bits
,
group_size
,
sym
)
if
backend
==
"auto"
or
"marlin"
in
backend
:
GPTQ_TYPE_MAP
=
{
(
4
,
True
):
scalar_types
.
uint4b8
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
use_marlin
=
((
weight_bits
,
sym
)
in
GPTQ_TYPE_MAP
and
check_marlin_supported
(
GPTQ_TYPE_MAP
[(
weight_bits
,
sym
)],
group_size
,
has_zp
=
not
sym
))
if
isinstance
(
layer
,
FusedMoE
):
use_marlin
=
check_moe_marlin_supports_layer
(
layer
,
group_size
)
else
:
GPTQ_TYPE_MAP
=
{
(
4
,
True
):
scalar_types
.
uint4b8
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
use_marlin
=
((
weight_bits
,
sym
)
in
GPTQ_TYPE_MAP
and
check_marlin_supported
(
GPTQ_TYPE_MAP
[(
weight_bits
,
sym
)],
group_size
,
has_zp
=
not
sym
))
use_marlin
=
use_marlin
and
check_moe_marlin_supports_layer
(
layer
,
group_size
)
else
:
use_marlin
=
False
if
use_marlin
:
...
...
@@ -251,11 +253,11 @@ class AutoRoundConfig(QuantizationConfig):
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
MoeWNA16Config
)
config
=
{
"
linear_
quant_method"
:
"gptq"
,
"
weight_
bits"
:
weight_bits
,
"quant_method"
:
"gptq"
,
"bits"
:
weight_bits
,
"group_size"
:
group_size
,
"sym"
:
sym
,
"lm_head
_quantized
"
:
False
,
"lm_head"
:
False
,
}
return
MoeWNA16Config
.
from_config
(
config
).
get_quant_method
(
layer
,
prefix
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment