Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9025a9a7
Unverified
Commit
9025a9a7
authored
Jul 01, 2025
by
Kyle Sayers
Committed by
GitHub
Jul 01, 2025
Browse files
[Quant] [Bugfix] Fix quantization config matching with `hf_to_vllm_mapper` (#20046)
parent
c05596f1
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
107 additions
and
29 deletions
+107
-29
tests/quantization/test_register_quantization_config.py
tests/quantization/test_register_quantization_config.py
+1
-0
vllm/lora/models.py
vllm/lora/models.py
+1
-1
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+1
-4
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+13
-0
vllm/model_executor/layers/quantization/bitblas.py
vllm/model_executor/layers/quantization/bitblas.py
+1
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+16
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+9
-1
vllm/model_executor/layers/quantization/gptq_bitblas.py
vllm/model_executor/layers/quantization/gptq_bitblas.py
+1
-0
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+2
-0
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+1
-0
vllm/model_executor/layers/quantization/torchao.py
vllm/model_executor/layers/quantization/torchao.py
+1
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+13
-9
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+20
-3
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+7
-7
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+1
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+14
-1
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+5
-2
No files found.
tests/quantization/test_register_quantization_config.py
View file @
9025a9a7
...
...
@@ -53,6 +53,7 @@ class CustomQuantConfig(QuantizationConfig):
def
__init__
(
self
,
num_bits
:
int
=
8
)
->
None
:
"""Initialize the quantization config."""
super
().
__init__
()
self
.
num_bits
=
num_bits
def
get_name
(
self
)
->
QuantizationMethods
:
...
...
vllm/lora/models.py
View file @
9025a9a7
...
...
@@ -805,7 +805,7 @@ def create_lora_manager(
lora_manager_cls
:
type
[
LoRAModelManager
]
=
LoRAModelManager
,
**
kwargs
)
->
LoRAModelManager
:
"""Create a LoRA adapter for a given model."""
if
not
hasattr
(
model
,
"packed_modules_mapping"
):
if
not
isinstance
(
model
,
SupportsLoRA
):
raise
ValueError
(
f
"Model
{
type
(
model
)
}
is not supported for LoRA."
)
lora_manager
=
lora_manager_cls
(
model
=
model
,
...
...
vllm/lora/worker_manager.py
View file @
9025a9a7
...
...
@@ -111,10 +111,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
model
=
self
.
_adapter_manager
.
model
hf_to_vllm_mapper
=
None
if
(
hasattr
(
model
,
"hf_to_vllm_mapper"
)
and
model
.
hf_to_vllm_mapper
is
not
None
):
hf_to_vllm_mapper
=
model
.
hf_to_vllm_mapper
hf_to_vllm_mapper
=
getattr
(
model
,
"hf_to_vllm_mapper"
,
None
)
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora_path
,
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
9025a9a7
...
...
@@ -10,6 +10,7 @@ from torch import nn
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.models.utils
import
WeightsMapper
else
:
QuantizationMethods
=
str
...
...
@@ -149,3 +150,15 @@ class QuantizationConfig(ABC):
def
get_cache_scale
(
self
,
name
:
str
)
->
Optional
[
str
]:
return
None
def
apply_vllm_mapper
(
# noqa: B027
self
,
hf_to_vllm_mapper
:
"WeightsMapper"
):
"""
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass
vllm/model_executor/layers/quantization/bitblas.py
View file @
9025a9a7
...
...
@@ -63,6 +63,7 @@ class BitBLASConfig(QuantizationConfig):
# (since we have only one group per output channel)
desc_act
=
False
super
().
__init__
()
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
9025a9a7
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
suppress
from
typing
import
Any
,
Literal
,
Optional
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
Optional
,
cast
import
torch
from
compressed_tensors.config
import
(
CompressionFormat
,
...
...
@@ -37,6 +37,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
cutlass_fp4_supported
)
from
vllm.platforms
import
current_platform
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
logger
=
init_logger
(
__name__
)
__all__
=
[
"CompressedTensorsLinearMethod"
]
...
...
@@ -80,6 +83,18 @@ class CompressedTensorsConfig(QuantizationConfig):
def
get_name
(
self
)
->
QuantizationMethods
:
return
"compressed-tensors"
def
apply_vllm_mapper
(
self
,
hf_to_vllm_mapper
:
"WeightsMapper"
):
self
.
target_scheme_map
=
hf_to_vllm_mapper
.
apply_dict
(
self
.
target_scheme_map
)
self
.
ignore
=
hf_to_vllm_mapper
.
apply_list
(
self
.
ignore
)
self
.
sparsity_scheme_map
=
hf_to_vllm_mapper
.
apply_dict
(
self
.
sparsity_scheme_map
)
self
.
sparsity_ignore_list
=
hf_to_vllm_mapper
.
apply_list
(
self
.
sparsity_ignore_list
)
if
self
.
kv_cache_scheme
is
not
None
:
self
.
kv_cache_scheme
=
hf_to_vllm_mapper
.
apply_dict
(
self
.
kv_cache_scheme
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
9025a9a7
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch.nn.functional
as
F
...
...
@@ -39,6 +39,9 @@ from vllm.platforms import current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
has_deep_gemm
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
logger
=
init_logger
(
__name__
)
...
...
@@ -100,6 +103,11 @@ class Fp8Config(QuantizationConfig):
def
get_config_filenames
(
cls
)
->
list
[
str
]:
return
[]
def
apply_vllm_mapper
(
self
,
hf_to_vllm_mapper
:
"WeightsMapper"
):
if
self
.
ignored_layers
is
not
None
:
self
.
ignored_layers
=
hf_to_vllm_mapper
.
apply_list
(
self
.
ignored_layers
)
@
classmethod
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"Fp8Config"
:
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
...
...
vllm/model_executor/layers/quantization/gptq_bitblas.py
View file @
9025a9a7
...
...
@@ -81,6 +81,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
# (since we have only one group per output channel)
desc_act
=
False
super
().
__init__
()
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
...
...
vllm/model_executor/layers/quantization/marlin.py
View file @
9025a9a7
...
...
@@ -32,6 +32,8 @@ class MarlinConfig(QuantizationConfig):
group_size
:
int
,
lm_head_quantized
:
bool
,
)
->
None
:
super
().
__init__
()
# Group size for the quantization.
self
.
group_size
=
group_size
self
.
lm_head_quantized
=
lm_head_quantized
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
9025a9a7
...
...
@@ -181,6 +181,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
exclude_modules
:
list
[
str
],
group_size
:
int
=
16
,
)
->
None
:
super
().
__init__
()
self
.
is_checkpoint_nvfp4_serialized
=
is_checkpoint_nvfp4_serialized
if
is_checkpoint_nvfp4_serialized
:
logger
.
warning
(
...
...
vllm/model_executor/layers/quantization/torchao.py
View file @
9025a9a7
...
...
@@ -55,6 +55,7 @@ class TorchAOConfig(QuantizationConfig):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
super
().
__init__
()
self
.
torchao_config
=
torchao_config
self
.
skip_modules
=
skip_modules
or
[]
...
...
vllm/model_executor/model_loader/utils.py
View file @
9025a9a7
...
...
@@ -24,6 +24,7 @@ from vllm.model_executor.models import ModelRegistry
from
vllm.model_executor.models.adapters
import
(
as_classification_model
,
as_embedding_model
,
as_reward_model
)
from
vllm.model_executor.models.interfaces
import
SupportsQuant
from
vllm.utils
import
is_pin_memory_available
logger
=
init_logger
(
__name__
)
...
...
@@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig,
Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
Once the `SupportsQuant` mixin has been added to all models, this
function can be removed
"""
packed_mapping
=
getattr
(
model_class
,
"packed_modules_mapping"
,
None
)
if
packed_mapping
is
not
None
:
# pass packed_modules_mapping by reference to quant_config
quant_config
.
packed_modules_mapping
=
packed_mapping
else
:
logger
.
warning
(
"The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored "
"modules"
,
model_class
.
__name__
)
if
not
issubclass
(
model_class
,
SupportsQuant
):
hf_to_vllm_mapper
=
getattr
(
model_class
,
"hf_to_vllm_mapper"
,
None
)
packed_mapping
=
getattr
(
model_class
,
"packed_modules_mapping"
,
None
)
# pass mappings by reference to quant_config
if
hf_to_vllm_mapper
is
not
None
:
quant_config
.
apply_vllm_mapper
(
hf_to_vllm_mapper
)
if
packed_mapping
is
not
None
:
quant_config
.
packed_modules_mapping
=
packed_mapping
vllm/model_executor/models/interfaces.py
View file @
9025a9a7
...
...
@@ -18,6 +18,7 @@ from .interfaces_base import is_pooling_model
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.sequence
import
IntermediateTensors
logger
=
init_logger
(
__name__
)
...
...
@@ -566,20 +567,36 @@ def has_step_pooler(model: Union[type[object], object]) -> bool:
class
SupportsQuant
:
"""The interface required for all models that support quantization."""
packed_modules_mapping
:
ClassVar
[
dict
[
str
,
list
[
str
]]]
=
{}
hf_to_vllm_mapper
:
ClassVar
[
Optional
[
"WeightsMapper"
]]
=
None
packed_modules_mapping
:
ClassVar
[
Optional
[
dict
[
str
,
list
[
str
]]]]
=
None
quant_config
:
Optional
[
QuantizationConfig
]
=
None
def
__new__
(
cls
,
*
args
,
**
kwargs
)
->
Self
:
instance
=
super
().
__new__
(
cls
)
# find config passed in arguments
quant_config
=
cls
.
_find_quant_config
(
*
args
,
**
kwargs
)
if
quant_config
is
not
None
:
# attach config to model for general use
instance
.
quant_config
=
quant_config
instance
.
quant_config
.
packed_modules_mapping
.
update
(
cls
.
packed_modules_mapping
)
# apply model mappings to config for proper config-model matching
# NOTE: `TransformersForCausalLM` is not supported due to how this
# class defines `hf_to_vllm_mapper` as a post-init `@property`.
# After this is fixed, get `instance.hf_to_vllm_mapper` directly
if
getattr
(
instance
,
"hf_to_vllm_mapper"
,
None
)
is
not
None
:
instance
.
quant_config
.
apply_vllm_mapper
(
instance
.
hf_to_vllm_mapper
)
if
getattr
(
instance
,
"packed_modules_mapping"
,
None
)
is
not
None
:
instance
.
quant_config
.
packed_modules_mapping
.
update
(
instance
.
packed_modules_mapping
)
return
instance
@
staticmethod
def
_find_quant_config
(
*
args
,
**
kwargs
)
->
Optional
[
QuantizationConfig
]:
"""Find quant config passed through model constructor args"""
from
vllm.config
import
VllmConfig
# avoid circular import
args_values
=
list
(
args
)
+
list
(
kwargs
.
values
())
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
9025a9a7
...
...
@@ -61,7 +61,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
)
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
)
from
.qwen2_vl
import
Qwen2VLDummyInputsBuilder
as
Qwen2_5_VLDummyInputsBuilder
from
.qwen2_vl
import
(
Qwen2VLMultiModalProcessor
,
Qwen2VLProcessingInfo
,
apply_rotary_pos_emb_vision
)
...
...
@@ -821,7 +821,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
info
=
Qwen2_5_VLProcessingInfo
,
dummy_inputs
=
Qwen2_5_VLDummyInputsBuilder
)
class
Qwen2_5_VLForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsLoRA
,
SupportsPP
):
SupportsLoRA
,
SupportsPP
,
SupportsQuant
):
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper
=
WeightsMapper
(
...
...
@@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
:
Qwen2_5_VLConfig
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
...
...
@@ -846,7 +846,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
visual
=
Qwen2_5_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
),
quant_config
=
self
.
_maybe_ignore_quant_config
(
self
.
quant_config
),
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
)
...
...
@@ -859,12 +859,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_maybe_ignore_quant_config
(
self
,
quant_
config
:
QuantizationConfig
):
def
_maybe_ignore_quant_config
(
self
,
config
:
Optional
[
QuantizationConfig
]
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
if
isinstance
(
quant_
config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
if
isinstance
(
config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
return
None
return
quant_
config
return
config
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/models/transformers.py
View file @
9025a9a7
...
...
@@ -467,6 +467,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
# FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
# this makes thing complicated. We need to remove this mapper after refactor
# `TransformersModel` in the future.
# NOTE: `SupportsQuant` can be updated after property decorator is removed
@
property
def
hf_to_vllm_mapper
(
self
):
prefix_mapper
=
{
...
...
vllm/model_executor/models/utils.py
View file @
9025a9a7
...
...
@@ -4,7 +4,7 @@
import
itertools
from
collections.abc
import
Iterable
,
Mapping
from
dataclasses
import
dataclass
,
field
from
typing
import
Callable
,
Literal
,
Optional
,
Protocol
,
Union
,
overload
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
Protocol
,
Union
,
overload
import
torch
import
torch.nn
as
nn
...
...
@@ -64,6 +64,19 @@ class WeightsMapper:
return
((
out_name
,
data
)
for
name
,
data
in
weights
if
(
out_name
:
=
self
.
_map_name
(
name
))
is
not
None
)
def
apply_list
(
self
,
values
:
list
[
str
])
->
list
[
str
]:
return
[
out_name
for
name
in
values
if
(
out_name
:
=
self
.
_map_name
(
name
))
is
not
None
]
def
apply_dict
(
self
,
values
:
dict
[
str
,
Any
])
->
dict
[
str
,
Any
]:
return
{
out_name
:
value
for
name
,
value
in
values
.
items
()
if
(
out_name
:
=
self
.
_map_name
(
name
))
is
not
None
}
class
AutoWeightsLoader
:
"""
...
...
vllm/model_executor/utils.py
View file @
9025a9a7
...
...
@@ -58,7 +58,8 @@ def _make_synced_weight_loader(original_weight_loader):
def
get_packed_modules_mapping
(
model
:
torch
.
nn
.
Module
)
->
dict
[
str
,
list
[
str
]]:
parent_map
=
copy
.
deepcopy
(
getattr
(
model
,
"packed_modules_mapping"
,
{}))
parent_map
=
getattr
(
model
,
"packed_modules_mapping"
,
None
)
parent_map
=
copy
.
deepcopy
(
parent_map
)
if
parent_map
is
not
None
else
{}
# don't infer mapping if the model has defined it explicitly.
if
parent_map
:
...
...
@@ -66,7 +67,9 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
# We only check main components instead of whole model submodules
for
child
in
model
.
children
():
child_map
=
getattr
(
child
,
"packed_modules_mapping"
,
{})
child_map
=
getattr
(
child
,
"packed_modules_mapping"
,
None
)
child_map
=
copy
.
deepcopy
(
child_map
)
if
child_map
is
not
None
else
{}
if
any
((
k
in
parent_map
and
parent_map
[
k
]
!=
v
)
for
k
,
v
in
child_map
.
items
()):
raise
ValueError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment