Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5e5c8e09
Unverified
Commit
5e5c8e09
authored
Feb 14, 2025
by
Michael Goin
Committed by
GitHub
Feb 14, 2025
Browse files
[Quant][Perf] Use moe_wna16 kernel by default for MoEs with many experts (#13236)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
c9e2d644
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
26 deletions
+39
-26
tests/weight_loading/test_weight_loading.py
tests/weight_loading/test_weight_loading.py
+1
-1
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+7
-1
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+16
-19
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+15
-5
No files found.
tests/weight_loading/test_weight_loading.py
View file @
5e5c8e09
...
...
@@ -12,7 +12,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME",
"robertgshaw2/zephyr-7b-beta-channelwise-gptq"
)
REVISION
=
os
.
environ
.
get
(
"REVISION"
,
"main"
)
QUANTIZATION
=
os
.
environ
.
get
(
"QUANTIZATION"
,
"gptq_marlin"
)
MIN_CAPABILITY
=
os
.
environ
.
get
(
"MIN_CAPABILITY"
,
"8
9
"
)
MIN_CAPABILITY
=
os
.
environ
.
get
(
"MIN_CAPABILITY"
,
"8
0
"
)
@
pytest
.
mark
.
skipif
(
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
5e5c8e09
...
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.moe_wna16
import
MoeWNA16Config
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
...
...
@@ -134,7 +135,12 @@ class AWQMarlinConfig(QuantizationConfig):
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
return
AWQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
AWQMoEMethod
(
self
)
if
layer
.
num_experts
>
32
:
# For MoEs with many experts the moe_wna16 kernel is faster
return
MoeWNA16Config
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
else
:
return
AWQMoEMethod
(
self
)
return
None
@
classmethod
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
5e5c8e09
...
...
@@ -10,20 +10,18 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
UnquantizedLinearMethod
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.moe_wna16
import
MoeWNA16Config
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
get_linear_quant_method
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_marlin_supported
,
marlin_moe_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
verify_marlin_supported
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
UnquantizedEmbeddingMethod
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
...
...
@@ -44,15 +42,10 @@ class GPTQMarlinConfig(QuantizationConfig):
(
8
,
True
):
scalar_types
.
uint8b128
,
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
dynamic
:
Dict
[
str
,
Dict
[
str
,
Union
[
int
,
bool
]]],
)
->
None
:
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
dynamic
:
Dict
[
str
,
Dict
[
str
,
Union
[
int
,
bool
]]],
full_config
:
Dict
[
str
,
Any
])
->
None
:
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
...
...
@@ -90,6 +83,7 @@ class GPTQMarlinConfig(QuantizationConfig):
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
lm_head_quantized
=
lm_head_quantized
self
.
full_config
=
full_config
if
(
weight_bits
,
is_sym
)
not
in
self
.
TYPE_MAP
:
raise
ValueError
(
"Unsupported quantization config: "
...
...
@@ -132,7 +126,7 @@ class GPTQMarlinConfig(QuantizationConfig):
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
,
lm_head_quantized
,
dynamic
)
lm_head_quantized
,
dynamic
,
config
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
...
...
@@ -155,12 +149,15 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference"
)
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
Union
[
"GPTQMarlinLinearMethod"
,
"GPTQMarlinMoEMethod"
,
UnquantizedLinearMethod
,
UnquantizedEmbeddingMethod
]]:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
FusedMoE
):
return
GPTQMarlinMoEMethod
(
self
)
if
layer
.
num_experts
>
32
:
# For MoEs with many experts the moe_wna16 kernel is faster
return
MoeWNA16Config
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
else
:
return
GPTQMarlinMoEMethod
(
self
)
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
GPTQMarlinLinearMethod
)
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
5e5c8e09
...
...
@@ -9,13 +9,8 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
AWQMarlinConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_marlin_supports_layer
)
from
vllm.model_executor.utils
import
set_weight_attrs
...
...
@@ -37,6 +32,12 @@ class MoeWNA16Config(QuantizationConfig):
self
.
linear_quant_method
=
linear_quant_method
self
.
full_config
=
full_config
self
.
use_marlin
=
False
# Avoid circular import
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
AWQMarlinConfig
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
if
self
.
linear_quant_method
==
"gptq"
:
self
.
use_marlin
=
GPTQMarlinConfig
.
is_gptq_marlin_compatible
(
full_config
)
...
...
@@ -115,6 +116,8 @@ class MoeWNA16Config(QuantizationConfig):
capability_tuple
=
current_platform
.
get_device_capability
()
device_capability
=
(
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
())
# Avoid circular import
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
awq_min_capability
=
AWQConfig
.
get_min_capability
()
gptq_compatible
=
quant_method
==
"gptq"
and
\
...
...
@@ -129,6 +132,13 @@ class MoeWNA16Config(QuantizationConfig):
if
is_layer_skipped_quant
(
prefix
,
self
.
modules_to_not_convert
):
return
UnquantizedLinearMethod
()
elif
isinstance
(
layer
,
LinearBase
):
# Avoid circular import
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
AWQMarlinConfig
)
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
if
self
.
linear_quant_method
==
"gptq"
:
if
self
.
use_marlin
:
return
GPTQMarlinConfig
.
from_config
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment