Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4645024d
Unverified
Commit
4645024d
authored
Aug 23, 2025
by
Isotr0py
Committed by
GitHub
Aug 22, 2025
Browse files
[Quantization] Allow GGUF quantization to skip unquantized layer (#23188)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
cd7a3df2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
3 deletions
+36
-3
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+11
-2
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+13
-1
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+12
-0
No files found.
vllm/model_executor/layers/quantization/gguf.py
View file @
4645024d
...
...
@@ -13,7 +13,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEConfig
,
FusedMoEMethodBase
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
...
...
@@ -28,8 +29,10 @@ logger = init_logger(__name__)
class
GGUFConfig
(
QuantizationConfig
):
"""Config class for GGUF."""
def
__init__
(
self
,
)
->
None
:
def
__init__
(
self
,
unquantized_modules
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
unquantized_modules
=
unquantized_modules
or
[]
def
__repr__
(
self
)
->
str
:
return
(
"GGUFConfig()"
)
...
...
@@ -55,6 +58,8 @@ class GGUFConfig(QuantizationConfig):
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped_gguf
(
prefix
,
self
.
unquantized_modules
):
return
UnquantizedLinearMethod
()
return
GGUFLinearMethod
(
self
)
elif
isinstance
(
layer
,
VocabParallelEmbedding
):
return
GGUFEmbeddingMethod
(
self
)
...
...
@@ -63,6 +68,10 @@ class GGUFConfig(QuantizationConfig):
return
None
def
is_layer_skipped_gguf
(
prefix
:
str
,
unquantized_modules
:
list
[
str
]):
return
any
(
module_name
in
prefix
for
module_name
in
unquantized_modules
)
UNQUANTIZED_TYPES
=
{
WeightType
.
F32
,
WeightType
.
F16
,
WeightType
.
BF16
}
STANDARD_QUANT_TYPES
=
{
WeightType
.
Q4_0
,
...
...
vllm/model_executor/model_loader/gguf_loader.py
View file @
4645024d
...
...
@@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from
vllm.model_executor.model_loader.utils
import
(
initialize_model
,
process_weights_after_loading
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
get_gguf_extra_tensor_names
,
gguf_quant_weights_iterator
)
get_gguf_extra_tensor_names
,
get_gguf_weight_type_map
,
gguf_quant_weights_iterator
)
class
GGUFModelLoader
(
BaseModelLoader
):
...
...
@@ -132,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader):
local_model_path
,
gguf_weights_map
):
model_config
.
hf_config
.
update
({
"tie_word_embeddings"
:
True
})
weight_type_map
=
get_gguf_weight_type_map
(
model_config
.
model
,
gguf_weights_map
)
# filter out unquantized modules to skip
unquant_names
=
[
name
.
removesuffix
(
".weight"
)
for
name
,
weight_type
in
weight_type_map
.
items
()
if
weight_type
==
"F32"
and
name
.
endswith
(
".weight"
)
]
vllm_config
.
quant_config
.
unquantized_modules
.
extend
(
unquant_names
)
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
4645024d
...
...
@@ -563,6 +563,18 @@ def get_gguf_extra_tensor_names(
return
[
gguf_to_hf_name_map
[
key
]
for
key
in
extra_keys
]
def
get_gguf_weight_type_map
(
gguf_file
:
str
,
gguf_to_hf_name_map
:
dict
[
str
,
str
])
->
dict
[
str
,
str
]:
"""
Return GGUF mapped weight's name and its quant type
"""
reader
=
gguf
.
GGUFReader
(
gguf_file
)
return
{
gguf_to_hf_name_map
[
tensor
.
name
]:
tensor
.
tensor_type
.
name
for
tensor
in
reader
.
tensors
if
tensor
.
name
in
gguf_to_hf_name_map
}
def
gguf_quant_weights_iterator
(
gguf_file
:
str
,
gguf_to_hf_name_map
:
dict
[
str
,
str
]
)
->
Generator
[
tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment