Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d4d98998
Unverified
Commit
d4d98998
authored
Sep 26, 2025
by
Isotr0py
Committed by
GitHub
Sep 26, 2025
Browse files
[Quantization] Add field to skip unquantized modules for GPTQ config (#25455)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
db1e42f6
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
219 additions
and
153 deletions
+219
-153
vllm/config/__init__.py
vllm/config/__init__.py
+1
-0
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+6
-0
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+46
-6
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+55
-10
vllm/model_executor/layers/quantization/utils/gptq_utils.py
vllm/model_executor/layers/quantization/utils/gptq_utils.py
+51
-1
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+2
-10
vllm/model_executor/models/minicpmo.py
vllm/model_executor/models/minicpmo.py
+1
-35
vllm/model_executor/models/ovis.py
vllm/model_executor/models/ovis.py
+1
-12
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+1
-12
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-12
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+6
-22
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+5
-19
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+1
-11
vllm/model_executor/models/qwen3_vl_moe.py
vllm/model_executor/models/qwen3_vl_moe.py
+1
-1
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+31
-1
vllm/transformers_utils/utils.py
vllm/transformers_utils/utils.py
+10
-1
No files found.
vllm/config/__init__.py
View file @
d4d98998
...
...
@@ -270,6 +270,7 @@ class VllmConfig:
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
quant_config
.
maybe_update_config
(
model_config
.
model
)
return
quant_config
return
None
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
d4d98998
...
...
@@ -162,3 +162,9 @@ class QuantizationConfig(ABC):
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass
def
maybe_update_config
(
self
,
model_name
:
str
):
# noqa: B027
"""
Interface to update values after config initialization.
"""
pass
vllm/model_executor/layers/quantization/gptq.py
View file @
d4d98998
...
...
@@ -7,6 +7,7 @@ from fractions import Fraction
from
typing
import
Any
,
Optional
,
Union
import
torch
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
...
...
@@ -22,6 +23,8 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
PackedColumnParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
from
vllm.transformers_utils.config
import
get_safetensors_params_metadata
from
vllm.utils
import
is_list_of
class
GPTQConfig
(
QuantizationConfig
):
...
...
@@ -38,6 +41,7 @@ class GPTQConfig(QuantizationConfig):
lm_head_quantized
:
bool
,
dynamic
:
dict
[
str
,
dict
[
str
,
Union
[
int
,
bool
]]],
autoround_version
:
str
=
""
,
modules_in_block_to_quantize
:
Optional
[
list
[
str
]]
=
None
,
)
->
None
:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
...
...
@@ -75,15 +79,20 @@ class GPTQConfig(QuantizationConfig):
"Currently, only 2/3/4/8-bit weight quantization is "
f
"supported for GPTQ, but got
{
self
.
weight_bits
}
bits."
)
self
.
modules_in_block_to_quantize
=
modules_in_block_to_quantize
or
[]
# used to identify GPTQ model quantized by autoround
self
.
autoround_version
=
autoround_version
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
), "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
), "
f
"dynamic=
{
self
.
dynamic
}
"
)
return
(
f
"GPTQConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
), "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
, "
f
"dynamic=
{
self
.
dynamic
}
, "
f
"modules_in_block_to_quantize=
{
self
.
modules_in_block_to_quantize
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
QuantizationMethods
:
...
...
@@ -114,8 +123,10 @@ class GPTQConfig(QuantizationConfig):
default
=
False
)
autoround_version
=
cls
.
get_from_keys_or
(
config
,
[
"autoround_version"
],
default
=
""
)
modules_in_block_to_quantize
=
cls
.
get_from_keys_or
(
config
,
[
"modules_in_block_to_quantize"
],
default
=
None
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
,
dynamic
,
autoround_version
)
dynamic
,
autoround_version
,
modules_in_block_to_quantize
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
...
...
@@ -136,6 +147,35 @@ class GPTQConfig(QuantizationConfig):
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
GPTQLinearMethod
)
def
apply_vllm_mapper
(
self
,
hf_to_vllm_mapper
):
if
self
.
modules_in_block_to_quantize
is
not
None
:
self
.
modules_in_block_to_quantize
=
hf_to_vllm_mapper
.
apply_list
(
self
.
modules_in_block_to_quantize
)
def
maybe_update_config
(
self
,
model_name
:
str
,
revision
:
Optional
[
str
]
=
None
):
if
self
.
modules_in_block_to_quantize
:
if
is_list_of
(
self
.
modules_in_block_to_quantize
,
list
):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self
.
modules_in_block_to_quantize
=
[
item
for
sublist
in
self
.
modules_in_block_to_quantize
for
item
in
sublist
]
return
unquant_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
metadata
=
get_safetensors_params_metadata
(
model_name
,
revision
=
revision
)
quant_layers
:
set
[
str
]
=
{
param_name
.
rsplit
(
"."
,
1
)[
0
]
for
param_name
,
info
in
metadata
.
items
()
if
(
dtype
:
=
info
.
get
(
'dtype'
,
None
))
and
_SAFETENSORS_TO_TORCH_DTYPE
[
dtype
]
not
in
unquant_dtypes
}
self
.
modules_in_block_to_quantize
=
list
(
quant_layers
)
class
ExllamaState
(
Enum
):
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
d4d98998
...
...
@@ -5,6 +5,7 @@ from copy import deepcopy
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
import
vllm.model_executor.layers.fused_moe
# noqa
from
vllm
import
_custom_ops
as
ops
...
...
@@ -35,6 +36,8 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
RowvLLMParameter
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.transformers_utils.config
import
get_safetensors_params_metadata
from
vllm.utils
import
is_list_of
logger
=
init_logger
(
__name__
)
...
...
@@ -71,10 +74,16 @@ class GPTQMarlinConfig(QuantizationConfig):
(
8
,
True
):
scalar_types
.
uint8b128
,
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
dynamic
:
dict
[
str
,
dict
[
str
,
Union
[
int
,
bool
]]],
full_config
:
dict
[
str
,
Any
])
->
None
:
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
dynamic
:
dict
[
str
,
dict
[
str
,
Union
[
int
,
bool
]]],
full_config
:
dict
[
str
,
Any
],
modules_in_block_to_quantize
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
super
().
__init__
()
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
...
...
@@ -121,15 +130,19 @@ class GPTQMarlinConfig(QuantizationConfig):
self
.
quant_type
=
self
.
TYPE_MAP
[(
weight_bits
,
is_sym
)]
self
.
modules_in_block_to_quantize
=
modules_in_block_to_quantize
or
[]
# used to identify GPTQ model quantized by autoround
self
.
autoround_version
=
full_config
.
get
(
"autoround_version"
,
""
)
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
), "
f
"dynamic=
{
self
.
dynamic
}
"
)
return
(
f
"GPTQMarlinConfig(quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
, "
f
"dynamic=
{
self
.
dynamic
}
, "
f
"modules_in_block_to_quantize=
{
self
.
modules_in_block_to_quantize
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
QuantizationMethods
:
...
...
@@ -158,8 +171,11 @@ class GPTQMarlinConfig(QuantizationConfig):
is_sym
=
cls
.
get_from_keys
(
config
,
[
"sym"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
modules_in_block_to_quantize
=
cls
.
get_from_keys_or
(
config
,
[
"modules_in_block_to_quantize"
],
default
=
None
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
,
lm_head_quantized
,
dynamic
,
config
)
lm_head_quantized
,
dynamic
,
config
,
modules_in_block_to_quantize
)
@
classmethod
def
override_quantization_method
(
...
...
@@ -223,6 +239,35 @@ class GPTQMarlinConfig(QuantizationConfig):
return
check_marlin_supported
(
quant_type
=
cls
.
TYPE_MAP
[(
num_bits
,
sym
)],
group_size
=
group_size
)
def
apply_vllm_mapper
(
self
,
hf_to_vllm_mapper
):
if
self
.
modules_in_block_to_quantize
is
not
None
:
self
.
modules_in_block_to_quantize
=
hf_to_vllm_mapper
.
apply_list
(
self
.
modules_in_block_to_quantize
)
def
maybe_update_config
(
self
,
model_name
:
str
,
revision
:
Optional
[
str
]
=
None
):
if
self
.
modules_in_block_to_quantize
:
if
is_list_of
(
self
.
modules_in_block_to_quantize
,
list
):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self
.
modules_in_block_to_quantize
=
[
item
for
sublist
in
self
.
modules_in_block_to_quantize
for
item
in
sublist
]
return
unquant_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
metadata
=
get_safetensors_params_metadata
(
model_name
,
revision
=
revision
)
quant_layers
:
set
[
str
]
=
{
param_name
.
rsplit
(
"."
,
1
)[
0
]
for
param_name
,
info
in
metadata
.
items
()
if
(
dtype
:
=
info
.
get
(
'dtype'
,
None
))
and
_SAFETENSORS_TO_TORCH_DTYPE
[
dtype
]
not
in
unquant_dtypes
}
self
.
modules_in_block_to_quantize
=
list
(
quant_layers
)
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
"""Linear method for GPTQ Marlin.
...
...
vllm/model_executor/layers/quantization/utils/gptq_utils.py
View file @
d4d98998
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Mapping
from
copy
import
deepcopy
from
fractions
import
Fraction
from
types
import
MappingProxyType
from
typing
import
Optional
,
Union
import
regex
as
re
...
...
@@ -70,6 +72,49 @@ def get_dynamic_override(
return
default_value
def
is_layer_gptq_quantized
(
prefix
:
str
,
quantized_layers
:
list
[
str
],
fused_mapping
:
Mapping
[
str
,
list
[
str
]]
=
MappingProxyType
({})
)
->
bool
:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
# GPTQ's `modules_in_block_to_quantize`:
# Substr: ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"]
# Full prefix ["model.layers.0.self_attn.q_proj"]
proj_name
=
prefix
.
split
(
"."
)[
-
1
]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if
proj_name
in
fused_mapping
:
shard_prefixes
=
[
prefix
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
fused_mapping
[
proj_name
]
]
is_quantized
=
None
for
shard_prefix
in
shard_prefixes
:
is_shard_quantized
=
any
(
layer
in
shard_prefix
for
layer
in
quantized_layers
)
if
is_quantized
is
None
:
is_quantized
=
is_shard_quantized
elif
is_shard_quantized
!=
is_quantized
:
raise
ValueError
(
f
"Detected some but not all shards of
{
prefix
}
"
"are quantized. All shards of fused layers "
"to have the same precision."
)
else
:
is_quantized
=
any
(
layer
in
prefix
for
layer
in
quantized_layers
)
assert
is_quantized
is
not
None
return
is_quantized
def
get_linear_quant_method
(
config
:
QuantizationConfig
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -80,10 +125,15 @@ def get_linear_quant_method(
parallel_lm_head_quantized
=
isinstance
(
layer
,
ParallelLMHead
)
and
cloned_config
.
lm_head_quantized
if
isinstance
(
layer
,
LinearBase
)
or
parallel_lm_head_quantized
:
is_layer_quantized
=
is_layer_gptq_quantized
(
prefix
=
prefix
,
quantized_layers
=
cloned_config
.
modules_in_block_to_quantize
,
fused_mapping
=
cloned_config
.
packed_modules_mapping
)
# False = skip module, None = no override, else = Positive match
if
get_dynamic_override
(
# noqa: E712
cloned_config
,
# noqa: E712
layer_name
=
prefix
)
==
False
:
# noqa: E712
layer_name
=
prefix
)
==
False
or
(
not
is_layer_quantized
):
# noqa: E712
if
parallel_lm_head_quantized
:
return
UnquantizedEmbeddingMethod
()
return
UnquantizedLinearMethod
()
...
...
vllm/model_executor/models/keye.py
View file @
d4d98998
...
...
@@ -25,9 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
...
...
@@ -1281,11 +1278,6 @@ class BaseKeyeModule(nn.Module):
raise
ValueError
(
"Only image or video modality is supported"
)
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
if
isinstance
(
quant_config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
return
None
return
quant_config
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
:
PretrainedConfig
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -1297,14 +1289,14 @@ class BaseKeyeModule(nn.Module):
self
.
visual
=
KeyeSiglipVisionModel
(
config
.
vision_config
,
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
)
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
)
self
.
mlp_AR
=
self
.
_build_projector
(
config
,
config
.
vision_config
,
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
)
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"mlp_AR"
),
)
...
...
vllm/model_executor/models/minicpmo.py
View file @
d4d98998
...
...
@@ -28,7 +28,7 @@ from typing import Annotated, Any, Callable, Literal, Optional, Union
import
torch
from
torch
import
nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
BatchFeature
from
transformers.modeling_outputs
import
BaseModelOutputWithPast
from
transformers.models.whisper.modeling_whisper
import
(
ACT2FN
,
WhisperAttention
,
...
...
@@ -36,10 +36,6 @@ from transformers.models.whisper.modeling_whisper import (ACT2FN,
WhisperEncoder
)
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargsItems
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
NestedTensors
)
...
...
@@ -548,36 +544,6 @@ class MiniCPMO(MiniCPMV2_6):
self
.
audio_token_id
=
None
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
# See: https://huggingface.co/openbmb/MiniCPM-o-2_6-int4
if
isinstance
(
quant_config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
return
None
return
quant_config
def
init_vision_module
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
# MiniCPMO GPTQ model leave vpm unquantized.
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
)
return
super
().
init_vision_module
(
config
,
quant_config
,
prefix
)
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
# MiniCPMO GPTQ model leave resampler unquantized.
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
)
return
super
().
init_resampler
(
embed_dim
,
vision_dim
,
quant_config
,
prefix
)
def
init_audio_module
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
# Do not use parameters temporarily
audio_config
=
self
.
config
.
audio_config
...
...
vllm/model_executor/models/ovis.py
View file @
d4d98998
...
...
@@ -31,9 +31,6 @@ from vllm.config import VllmConfig
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.models.aimv2
import
AIMv2Model
from
vllm.model_executor.models.siglip
import
SiglipVisionModel
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
...
...
@@ -416,7 +413,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
self
.
visual_tokenizer
=
VisualTokenizer
(
config
=
config
.
visual_tokenizer_config
,
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
)
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.visual_tokenizer"
,
)
...
...
@@ -430,14 +427,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
get_language_model
().
make_empty_intermediate_tensors
)
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
# See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4
if
isinstance
(
quant_config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
return
None
return
quant_config
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
OvisImagePatchInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
d4d98998
...
...
@@ -52,9 +52,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
# yapf: enable
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -1015,8 +1012,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
visual
=
Qwen2_5_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
self
.
_maybe_ignore_quant_config
(
self
.
quant_config
),
quant_config
=
self
.
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
...
...
@@ -1032,13 +1028,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_maybe_ignore_quant_config
(
self
,
config
:
Optional
[
QuantizationConfig
]):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
if
isinstance
(
config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
return
None
return
config
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
d4d98998
...
...
@@ -50,9 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -1270,7 +1267,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
visual
=
Qwen2VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
)
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
...
...
@@ -1286,14 +1283,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
# See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
if
isinstance
(
quant_config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
return
None
return
quant_config
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
d4d98998
...
...
@@ -46,9 +46,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
...
...
@@ -149,24 +146,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
bias
=
False
,
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
),
prefix
=
f
"
{
prefix
}
.gate"
)
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization while AutoRound does.
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4,
# and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq
if
isinstance
(
quant_config
,
(
GPTQConfig
,
GPTQMarlinConfig
))
and
not
quant_config
.
autoround_version
:
return
None
return
quant_config
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
hidden_states
.
dim
(
...
...
@@ -699,4 +683,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
\ No newline at end of file
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/qwen3_next.py
View file @
d4d98998
...
...
@@ -41,9 +41,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
...
...
@@ -119,12 +116,11 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
bias
=
False
,
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
),
prefix
=
f
"
{
prefix
}
.gate"
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate"
)
if
config
.
shared_expert_intermediate_size
>
0
:
self
.
shared_expert
=
Qwen3NextMLP
(
...
...
@@ -142,16 +138,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
1
,
bias
=
False
)
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization while AutoRound does.
if
isinstance
(
quant_config
,
(
GPTQConfig
,
GPTQMarlinConfig
))
and
not
quant_config
.
autoround_version
:
return
None
return
quant_config
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
d4d98998
...
...
@@ -50,9 +50,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
...
...
@@ -1058,7 +1055,7 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
)
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
...
...
@@ -1116,13 +1113,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
for
idx
in
range
(
self
.
deepstack_num_level
):
self
.
deepstack_input_embeds
[
idx
][:
num_tokens
].
zero_
()
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
if
isinstance
(
quant_config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
return
None
return
quant_config
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
...
...
vllm/model_executor/models/qwen3_vl_moe.py
View file @
d4d98998
...
...
@@ -322,7 +322,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
)
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
...
...
vllm/transformers_utils/config.py
View file @
d4d98998
...
...
@@ -4,6 +4,7 @@
import
json
import
os
import
time
from
dataclasses
import
asdict
from
functools
import
cache
,
partial
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
TypeVar
,
Union
...
...
@@ -27,7 +28,8 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.config_parser_base
import
ConfigParserBase
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
(
check_gguf_file
,
parse_safetensors_file_metadata
)
if
envs
.
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
...
...
@@ -999,6 +1001,34 @@ def try_get_tokenizer_config(
return
None
def
get_safetensors_params_metadata
(
model
:
str
,
*
,
revision
:
Optional
[
str
]
=
None
,
)
->
dict
[
str
,
Any
]:
"""
Get the safetensors metadata for remote model repository.
"""
full_metadata
=
{}
if
(
model_path
:
=
Path
(
model
)).
exists
():
safetensors_to_check
=
model_path
.
glob
(
"*.safetensors"
)
full_metadata
=
{
param_name
:
info
for
file_path
in
safetensors_to_check
if
file_path
.
is_file
()
for
param_name
,
info
in
parse_safetensors_file_metadata
(
file_path
).
items
()
}
else
:
repo_mt
=
try_get_safetensors_metadata
(
model
,
revision
=
revision
)
if
repo_mt
and
(
files_mt
:
=
repo_mt
.
files_metadata
):
full_metadata
=
{
param_name
:
asdict
(
info
)
for
file_mt
in
files_mt
.
values
()
for
param_name
,
info
in
file_mt
.
tensors
.
items
()
}
return
full_metadata
def
_download_mistral_config_file
(
model
,
revision
)
->
dict
:
config_file_name
=
"params.json"
config_dict
=
get_hf_file_to_dict
(
config_file_name
,
model
,
revision
)
...
...
vllm/transformers_utils/utils.py
View file @
d4d98998
...
...
@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
struct
from
functools
import
cache
from
os
import
PathLike
from
pathlib
import
Path
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
from
vllm.envs
import
VLLM_MODEL_REDIRECT_PATH
from
vllm.logger
import
init_logger
...
...
@@ -97,3 +98,11 @@ def maybe_model_redirect(model: str) -> str:
return
redirect_model
return
model
def
parse_safetensors_file_metadata
(
path
:
Union
[
str
,
PathLike
])
->
dict
[
str
,
Any
]:
with
open
(
path
,
"rb"
)
as
f
:
length_of_metadata
=
struct
.
unpack
(
'<Q'
,
f
.
read
(
8
))[
0
]
metadata
=
json
.
loads
(
f
.
read
(
length_of_metadata
).
decode
(
'utf-8'
))
return
metadata
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment