Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3654847d
Unverified
Commit
3654847d
authored
Aug 02, 2025
by
JartX
Committed by
GitHub
Aug 01, 2025
Browse files
feat: Add Support GPTQ Quantization MOE on ROCM vllm serve (#21733)
parent
eefbf4a6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
5 deletions
+21
-5
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-2
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+19
-3
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
3654847d
...
...
@@ -761,8 +761,8 @@ def get_moe_wna16_block_config(config: dict[str,
def
should_moe_wna16_use_cuda
(
num_valid_tokens
:
int
,
group_size
:
int
,
num_experts
:
int
,
bit
:
int
):
return
bit
==
4
and
group_size
in
[
32
,
64
,
128
]
and
\
num_valid_tokens
/
num_experts
<=
6
return
current_platform
.
is_cuda
()
and
bit
==
4
and
\
group_size
in
[
32
,
64
,
128
]
and
num_valid_tokens
/
num_experts
<=
6
def
get_default_config
(
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
3654847d
...
...
@@ -10,10 +10,11 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
get_linear_quant_method
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
...
...
@@ -110,8 +111,23 @@ class GPTQConfig(QuantizationConfig):
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
,
dynamic
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"GPTQLinearMethod"
]:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
Union
[
"GPTQLinearMethod"
,
"QuantizeMethodBase"
]]:
if
isinstance
(
layer
,
FusedMoE
):
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
from
.moe_wna16
import
MoeWNA16Config
config
=
{
"quant_method"
:
"gptq"
,
"bits"
:
self
.
weight_bits
,
"group_size"
:
self
.
group_size
,
"sym"
:
True
,
# GPTQ typically uses symmetric quantization
"lm_head"
:
False
,
}
return
MoeWNA16Config
.
from_config
(
config
).
get_quant_method
(
layer
,
prefix
)
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
GPTQLinearMethod
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment