Unverified Commit 7ff7a638 authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)


Signed-off-by: default avatarmgoin <michael@neuralmagic.com>
Signed-off-by: default avatarKyle Sayers <kylesayrs@gmail.com>
Co-authored-by: default avatarmgoin <michael@neuralmagic.com>
parent 686006a2
......@@ -2,7 +2,7 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Mapping, Optional, Type
import torch
from torch import nn
......@@ -59,6 +59,7 @@ def method_has_implemented_embedding(
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()
@abstractmethod
def get_name(self) -> str:
......
......@@ -83,7 +83,9 @@ class CompressedTensorsConfig(QuantizationConfig):
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix, ignore=self.ignore):
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
......@@ -379,34 +381,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# Will be empty for models with only sparsity
weight_quant = input_quant = None
sparsity_scheme: Optional[SparsityCompressionConfig] = None
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping)
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
if self.sparsity_scheme_map:
is_ignored = False
# Find the sparsity scheme of the layer
# assume that fused layers inerhit first component's sparsity scheme
sparsity_targets = (self.sparsity_scheme_map.keys() -
set(self.sparsity_ignore_list))
sparsity_scheme: Optional[SparsityCompressionConfig] = None
with suppress(ValueError):
is_ignored = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_ignore_list)
# if the layer is in the sparsity ignore list,
# we should not apply any sparsity scheme
if not is_ignored:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
targets=sparsity_targets,
fused_mapping=self.packed_modules_mapping)
sparsity_scheme = self.sparsity_scheme_map[matched_target]
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
......
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Iterable, Optional
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional
from compressed_tensors import CompressionFormat
from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
......@@ -19,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
......@@ -32,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in fused_mapping and layer_name not in ignore:
shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
......@@ -79,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str,
return False
def _handle_fused_layers(func):
"""
Decorator to handle fused layers by mapping vllm fused layer names
to their corresponding unfused layer names for quantization/pruning schemes.
"""
# fused_layer_name -> unfused_layer_name
fused_layer_map = {
"qkv_proj": "q_proj",
"gate_up_proj": "up_proj",
}
def fused_layer_handler(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> Optional[str]:
"""
Wrapper function specifically designed to support the
find_matched_target function.
It handles cases where the provided layer name corresponds to a
fused layer in vllm, mapping it to its equivalent unfused layer name
based on the predefined fused_layer_map. If the original layer name
raises a ValueError in the wrapped function, this handler
will attempt to resolve the issue by substituting with unfused
layer name.
:param layer_name: Name of the layer, which may be fused.
:param module: An instance of torch.nn.Module.
:param targets: A list of target names or patterns to match.
:return: The result of the wrapped find_matched_target function with
the resolved layer name.
:raises ValueError: If the layer name cannot be resolved to a
valid target.
"""
try:
return func(layer_name, module, targets)
except ValueError:
if layer_name is None:
layer_name = ""
parent_name, fused_proj_name = layer_name.rsplit(".", 1)
unfused_proj_name = fused_layer_map.get(fused_proj_name,
fused_proj_name)
new_layer_name = f"{parent_name}.{unfused_proj_name}"
return func(new_layer_name, module, targets)
return fused_layer_handler
@_handle_fused_layers
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
def find_matched_target(
layer_name: Optional[str],
module: Module,
targets: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
......@@ -141,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module,
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
Third, we try to map the layer_name to a list of fused module names.
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
:param fused_strategy: either "all" or "any". If using "all", fused
layers match if "all" of its components match
"""
if layer_name is None:
layer_name = ""
matched_target = (_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets,
True)
or _match_fused_layer(layer_name, targets))
matched_target = (
_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets, True)
or _match_fused_layer(layer_name, targets, fused_mapping))
if matched_target is None:
raise ValueError(
......@@ -205,11 +169,19 @@ def _is_equal_or_regex_match(value: str,
return False
def _match_fused_layer(layer_name: str,
target_layers: Iterable[str]) -> Optional[str]:
def _match_fused_layer(
layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
"""
Match a fused layer name to its corresponding individual layer in
target_layers.
target_layers. Returns first value in fused_mapping which matches targets
Implements an "all" matching strategy where a fused layer matches iff
"all" of its components match
:param layer_name: layer name
:param target_layers: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
Examples:
layer_name = "model.layers.0.self_attn.qkv_proj"
......@@ -217,27 +189,25 @@ def _match_fused_layer(layer_name: str,
"model.layers.0.self_attn.k_proj",
"model.layers.0.self_attn.v_proj"]
"""
# Split into parent path and layer type
# e.g., "model.layers.0.self_attn" and "qkv_proj"
parent_path = ".".join(layer_name.split(".")[:-1])
layer_type = layer_name.split(".")[-1]
if layer_type not in FUSED_LAYER_NAME_MAPPING:
# find layer_name in mapping
fused = next((key for key in fused_mapping if layer_name.endswith(key)),
None)
if fused is None:
return None
possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type]
# expand path of unfused components
unfused_paths = [
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
]
# Look for a target layer that:
# 1. Has the same parent path
# 2. Ends with one of the possible individual layer types
# for each unfused component, find a match in targets
unfused_matches: List[Optional[str]] = []
for unfused in unfused_paths:
for target in target_layers:
is_same_parent = parent_path in target
is_matching_type = any(type_suffix in target
for type_suffix in possible_layer_types)
if is_same_parent and is_matching_type and all(
(f"{parent_path}.{type_suffix}" in target_layers)
for type_suffix in possible_layer_types):
return target
if _is_equal_or_regex_match(unfused, target):
unfused_matches.append(target)
break
else:
unfused_matches.append(None)
return None
return unfused_matches[0] if all(unfused_matches) else None
......@@ -18,8 +18,6 @@ from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from vllm.platforms import current_platform
__all__ = ["QuarkLinearMethod"]
......@@ -58,7 +56,9 @@ class QuarkConfig(QuantizationConfig):
# Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix, ignore=exclude_layers):
if should_ignore_layer(prefix,
ignore=exclude_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
......@@ -201,8 +201,8 @@ class QuarkConfig(QuantizationConfig):
module: torch.nn.Module) -> Dict[str, Any]:
proj_name = layer_name.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in self.packed_modules_mapping:
shard_proj_names = self.packed_modules_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
......
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Any, Iterable, Optional
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from types import MappingProxyType
from typing import Any, Iterable, List, Mapping, Optional
def deep_compare(dict1: Any, dict2: Any) -> bool:
......@@ -20,8 +18,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
return dict1 == dict2
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
......@@ -33,8 +34,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in fused_mapping:
shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
......
# SPDX-License-Identifier: Apache-2.0
"""This file is used for /tests and /benchmarks"""
from typing import List, Optional, Tuple
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple
import numpy
import torch
......@@ -12,14 +13,6 @@ from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
......@@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm)
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
......
......@@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (ParamMapping,
configure_quant_config,
get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -113,6 +114,9 @@ def _initialize_model(
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
......
......@@ -11,6 +11,8 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
......@@ -138,3 +140,23 @@ class ParamMapping:
if module_name.endswith(key):
return key, value
return None
def configure_quant_config(quant_config: QuantizationConfig,
model_class: Type[nn.Module]):
"""
Pass packed_modules_mapping by reference to quant_config so that
quant_config can properly match fused modules
Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
"""
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
if packed_mapping is not None:
# pass packed_modules_mapping by reference to quant_config
quant_config.packed_modules_mapping = packed_mapping
else:
logger.warning(
"The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored "
"modules", model_class.__name__)
......@@ -265,12 +265,14 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
......@@ -327,6 +329,7 @@ class GLMMLP(nn.Module):
self,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
......@@ -338,6 +341,7 @@ class GLMMLP(nn.Module):
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.activation_func = SiluAndMul()
......@@ -348,6 +352,7 @@ class GLMMLP(nn.Module):
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
def forward(self, hidden_states):
......@@ -396,7 +401,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
self.mlp = GLMMLP(config, quant_config)
self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,
......@@ -507,7 +512,8 @@ class ChatGLMModel(nn.Module):
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.embedding")
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
......@@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
......@@ -777,9 +784,18 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config
# Initialize VL
if hasattr(config, "vision_config"):
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
if hasattr(config, "vision_config"): # noqa: SIM108
instance_cls = ChatGLMV
# Initialize LLM
else:
return ChatGLM(vllm_config=vllm_config, prefix=prefix)
\ No newline at end of file
instance_cls = ChatGLM
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
......@@ -74,11 +74,13 @@ class Attention(nn.Module):
self.head_dim,
config.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
......@@ -101,6 +103,7 @@ class MLP(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.config = config
......@@ -109,11 +112,13 @@ class MLP(nn.Module):
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -137,7 +142,9 @@ class TransformerLayer(nn.Module):
self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config)
self.mlp = MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
......@@ -164,7 +171,7 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([
TransformerLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}")
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
......@@ -181,6 +188,7 @@ class GLU(nn.Module):
config,
in_features,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
"""
The original implementation is the same as:
......@@ -222,7 +230,8 @@ class GLU(nn.Module):
self.linear_proj = ReplicatedLinear(in_features,
config.hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.norm1 = nn.LayerNorm(config.hidden_size)
self.act1 = nn.GELU()
self.act2 = SiluAndMul()
......@@ -230,12 +239,15 @@ class GLU(nn.Module):
self.merged_proj = MergedColumnParallelLinear(
config.hidden_size, [config.ffn_hidden_size] * 2,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.merged_proj")
self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size,
self.dense_4h_to_h = RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h")
def forward(self, x):
x, _ = self.linear_proj(x)
......@@ -262,7 +274,8 @@ class EVA2CLIPModel(nn.Module):
prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config,
in_features=config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
out_channels=config.hidden_size,
kernel_size=2,
......
......@@ -1473,6 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
"""
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
......@@ -1489,8 +1490,15 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version)
if instance_class is None:
instance_cls = _SUPPORT_VERSION.get(version)
if instance_cls is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(vllm_config=vllm_config, prefix=prefix)
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
......@@ -1135,6 +1135,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
"""
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
......@@ -1146,9 +1147,18 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
prefix: str = "",
) -> QWenBaseModel:
config = vllm_config.model_config.hf_config
# Initialize VL
if hasattr(config, "visual"):
return QWenVL(vllm_config=vllm_config, prefix=prefix)
if hasattr(config, "visual"): # noqa: SIM108
instance_cls = QWenVL
# Initialize LLM
else:
return QWenLLM(vllm_config=vllm_config, prefix=prefix)
instance_cls = QWenLLM
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment