Unverified commit 7ff7a638, authored by Kyle Sayers and committed by GitHub
Browse files

[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)


Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
parent 686006a2
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Mapping, Optional, Type
import torch import torch
from torch import nn from torch import nn
...@@ -59,6 +59,7 @@ def method_has_implemented_embedding( ...@@ -59,6 +59,7 @@ def method_has_implemented_embedding(
class QuantizationConfig(ABC): class QuantizationConfig(ABC):
"""Base class for quantization configs.""" """Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()
@abstractmethod @abstractmethod
def get_name(self) -> str: def get_name(self) -> str:
......
...@@ -83,7 +83,9 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -83,7 +83,9 @@ class CompressedTensorsConfig(QuantizationConfig):
# Check if the layer is skipped for quantization. # Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names # TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix, ignore=self.ignore): if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix) scheme = self.get_scheme(layer=layer, layer_name=prefix)
...@@ -379,34 +381,29 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -379,34 +381,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# Will be empty for models with only sparsity # Will be empty for models with only sparsity
weight_quant = input_quant = None weight_quant = input_quant = None
sparsity_scheme: Optional[SparsityCompressionConfig] = None
if self.target_scheme_map: if self.target_scheme_map:
matched_target = find_matched_target( matched_target = find_matched_target(
layer_name=layer_name, layer_name=layer_name,
module=layer, module=layer,
targets=self.target_scheme_map.keys()) targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping)
scheme_dict = self.target_scheme_map[matched_target] scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights") weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations") input_quant = scheme_dict.get("input_activations")
if self.sparsity_scheme_map: # Find the sparsity scheme of the layer
is_ignored = False # assume that fused layers inherit first component's sparsity scheme
with suppress(ValueError): sparsity_targets = (self.sparsity_scheme_map.keys() -
is_ignored = find_matched_target( set(self.sparsity_ignore_list))
layer_name=layer_name, sparsity_scheme: Optional[SparsityCompressionConfig] = None
module=layer, with suppress(ValueError):
targets=self.sparsity_ignore_list) matched_target = find_matched_target(
layer_name=layer_name,
# if the layer is in the sparsity ignore list, module=layer,
# we should not apply any sparsity scheme targets=sparsity_targets,
fused_mapping=self.packed_modules_mapping)
if not is_ignored: sparsity_scheme = self.sparsity_scheme_map[matched_target]
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
if self.supports_cutlass_24(weight_quant=weight_quant, if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant, input_quant=input_quant,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re import re
from typing import Iterable, Optional from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
from torch.nn import Module from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
def is_activation_quantization_format(format: str) -> bool: def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [ _ACTIVATION_QUANTIZATION_FORMATS = [
...@@ -19,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool: ...@@ -19,8 +17,11 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS return format in _ACTIVATION_QUANTIZATION_FORMATS
def should_ignore_layer(layer_name: Optional[str], def should_ignore_layer(
ignore: Iterable[str]) -> bool: layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None: if layer_name is None:
return False return False
...@@ -32,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str], ...@@ -32,8 +33,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name # in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that # from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme. # each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore: if proj_name in fused_mapping and layer_name not in ignore:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names] # Convert fused_name --> [shard_names]
shard_names = [ shard_names = [
...@@ -79,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str, ...@@ -79,55 +80,12 @@ def check_equal_or_regex_match(layer_name: str,
return False return False
def _handle_fused_layers(func): def find_matched_target(
""" layer_name: Optional[str],
Decorator to handle fused layers by mapping vllm fused layer names module: Module,
to their corresponding unfused layer names for quantization/pruning schemes. targets: Iterable[str],
""" fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
# fused_layer_name -> unfused_layer_name ) -> str:
fused_layer_map = {
"qkv_proj": "q_proj",
"gate_up_proj": "up_proj",
}
def fused_layer_handler(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> Optional[str]:
"""
Wrapper function specifically designed to support the
find_matched_target function.
It handles cases where the provided layer name corresponds to a
fused layer in vllm, mapping it to its equivalent unfused layer name
based on the predefined fused_layer_map. If the original layer name
raises a ValueError in the wrapped function, this handler
will attempt to resolve the issue by substituting with unfused
layer name.
:param layer_name: Name of the layer, which may be fused.
:param module: An instance of torch.nn.Module.
:param targets: A list of target names or patterns to match.
:return: The result of the wrapped find_matched_target function with
the resolved layer name.
:raises ValueError: If the layer name cannot be resolved to a
valid target.
"""
try:
return func(layer_name, module, targets)
except ValueError:
if layer_name is None:
layer_name = ""
parent_name, fused_proj_name = layer_name.rsplit(".", 1)
unfused_proj_name = fused_layer_map.get(fused_proj_name,
fused_proj_name)
new_layer_name = f"{parent_name}.{unfused_proj_name}"
return func(new_layer_name, module, targets)
return fused_layer_handler
@_handle_fused_layers
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
""" """
Helper function to look up which "target" in the compressed-tensors Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to. config that a layer corresponds to.
...@@ -141,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module, ...@@ -141,19 +99,25 @@ def find_matched_target(layer_name: Optional[str], module: Module,
First, we try to match the layer_name with a target First, we try to match the layer_name with a target
Second, we try to match the module's name with a target Second, we try to match the module's name with a target
Third, we try to map the layer_name to a list of fused module names.
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target
:param layer_name: layer name :param layer_name: layer name
:param module: torch.nn.Module :param module: torch.nn.Module
:param targets: list of targets to match the layer against :param targets: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
Note: the fused matching strategy is not configurable; fused layers
always use the "all" strategy and match only if all of their components match
""" """
if layer_name is None: if layer_name is None:
layer_name = "" layer_name = ""
matched_target = (_find_first_match(layer_name, targets) matched_target = (
or _find_first_match(module.__class__.__name__, targets, _find_first_match(layer_name, targets)
True) or _find_first_match(module.__class__.__name__, targets, True)
or _match_fused_layer(layer_name, targets)) or _match_fused_layer(layer_name, targets, fused_mapping))
if matched_target is None: if matched_target is None:
raise ValueError( raise ValueError(
...@@ -205,11 +169,19 @@ def _is_equal_or_regex_match(value: str, ...@@ -205,11 +169,19 @@ def _is_equal_or_regex_match(value: str,
return False return False
def _match_fused_layer(layer_name: str, def _match_fused_layer(
target_layers: Iterable[str]) -> Optional[str]: layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
""" """
Match a fused layer name to its corresponding individual layer in Match a fused layer name to its corresponding individual layer in
target_layers. target_layers. Returns first value in fused_mapping which matches targets
Implements an "all" matching strategy where a fused layer matches iff
"all" of its components match
:param layer_name: layer name
:param target_layers: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
Examples: Examples:
layer_name = "model.layers.0.self_attn.qkv_proj" layer_name = "model.layers.0.self_attn.qkv_proj"
...@@ -217,27 +189,25 @@ def _match_fused_layer(layer_name: str, ...@@ -217,27 +189,25 @@ def _match_fused_layer(layer_name: str,
"model.layers.0.self_attn.k_proj", "model.layers.0.self_attn.k_proj",
"model.layers.0.self_attn.v_proj"] "model.layers.0.self_attn.v_proj"]
""" """
# Split into parent path and layer type # find layer_name in mapping
# e.g., "model.layers.0.self_attn" and "qkv_proj" fused = next((key for key in fused_mapping if layer_name.endswith(key)),
parent_path = ".".join(layer_name.split(".")[:-1]) None)
layer_type = layer_name.split(".")[-1] if fused is None:
if layer_type not in FUSED_LAYER_NAME_MAPPING:
return None return None
possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type] # expand path of unfused components
unfused_paths = [
# Look for a target layer that: layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
# 1. Has the same parent path ]
# 2. Ends with one of the possible individual layer types
for target in target_layers:
is_same_parent = parent_path in target
is_matching_type = any(type_suffix in target
for type_suffix in possible_layer_types)
if is_same_parent and is_matching_type and all(
(f"{parent_path}.{type_suffix}" in target_layers)
for type_suffix in possible_layer_types):
return target
return None # for each unfused component, find a match in targets
unfused_matches: List[Optional[str]] = []
for unfused in unfused_paths:
for target in target_layers:
if _is_equal_or_regex_match(unfused, target):
unfused_matches.append(target)
break
else:
unfused_matches.append(None)
return unfused_matches[0] if all(unfused_matches) else None
...@@ -18,8 +18,6 @@ from vllm.model_executor.layers.quantization.quark.schemes import ( ...@@ -18,8 +18,6 @@ from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import ( from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer) deep_compare, should_ignore_layer)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from vllm.platforms import current_platform from vllm.platforms import current_platform
__all__ = ["QuarkLinearMethod"] __all__ = ["QuarkLinearMethod"]
...@@ -58,7 +56,9 @@ class QuarkConfig(QuantizationConfig): ...@@ -58,7 +56,9 @@ class QuarkConfig(QuantizationConfig):
# Check if the layer is skipped for quantization. # Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude")) exclude_layers = cast(List[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix, ignore=exclude_layers): if should_ignore_layer(prefix,
ignore=exclude_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix) scheme = self.get_scheme(layer=layer, layer_name=prefix)
...@@ -201,8 +201,8 @@ class QuarkConfig(QuantizationConfig): ...@@ -201,8 +201,8 @@ class QuarkConfig(QuantizationConfig):
module: torch.nn.Module) -> Dict[str, Any]: module: torch.nn.Module) -> Dict[str, Any]:
proj_name = layer_name.split(".")[-1] proj_name = layer_name.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING: if proj_name in self.packed_modules_mapping:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] shard_proj_names = self.packed_modules_mapping[proj_name]
# Convert fused_name --> [shard_names] # Convert fused_name --> [shard_names]
shard_names = [ shard_names = [
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re import re
from typing import Any, Iterable, Optional from types import MappingProxyType
from typing import Any, Iterable, List, Mapping, Optional
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
def deep_compare(dict1: Any, dict2: Any) -> bool: def deep_compare(dict1: Any, dict2: Any) -> bool:
...@@ -20,8 +18,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: ...@@ -20,8 +18,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
return dict1 == dict2 return dict1 == dict2
def should_ignore_layer(layer_name: Optional[str], def should_ignore_layer(
ignore: Iterable[str]) -> bool: layer_name: Optional[str],
ignore: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None: if layer_name is None:
return False return False
...@@ -33,8 +34,8 @@ def should_ignore_layer(layer_name: Optional[str], ...@@ -33,8 +34,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name # in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that # from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme. # each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING: if proj_name in fused_mapping:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names] # Convert fused_name --> [shard_names]
shard_names = [ shard_names = [
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from typing import List, Optional, Tuple from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple
import numpy import numpy
import torch import torch
...@@ -12,14 +13,6 @@ from vllm.scalar_type import ScalarType, scalar_types ...@@ -12,14 +13,6 @@ from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# Normalize the group_shape to the full extent for any dims that are -1 # Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int, def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
...@@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor, ...@@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm) return res.permute(inv_perm)
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj # prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj # proj_name: q_proj
proj_name = prefix.split(".")[-1] proj_name = prefix.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_prefixes = [ shard_prefixes = [
prefix.replace(proj_name, shard_proj_name) prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] for shard_proj_name in fused_mapping[proj_name]
] ]
is_skipped = None is_skipped = None
......
...@@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.tensorizer import ( ...@@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator) serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (ParamMapping, from vllm.model_executor.model_loader.utils import (ParamMapping,
configure_quant_config,
get_model_architecture, get_model_architecture,
set_default_torch_dtype) set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -113,6 +114,9 @@ def _initialize_model( ...@@ -113,6 +114,9 @@ def _initialize_model(
model_config = vllm_config.model_config model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config) model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
signatures = inspect.signature(model_class.__init__) signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()] all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params: if "vllm_config" in all_params and "prefix" in all_params:
......
...@@ -11,6 +11,8 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module ...@@ -11,6 +11,8 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module
from vllm.config import ModelConfig, ModelImpl from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model, from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model, as_embedding_model,
...@@ -138,3 +140,23 @@ class ParamMapping: ...@@ -138,3 +140,23 @@ class ParamMapping:
if module_name.endswith(key): if module_name.endswith(key):
return key, value return key, value
return None return None
def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: Type[nn.Module]):
    """
    Share the model class's `packed_modules_mapping` with `quant_config`
    so that the quantization config can match fused modules.

    The mapping object itself (not a copy) is assigned, so in-place updates
    made later to the model class attribute, e.g. by `model_class.__new__`
    (chatglm, qwen), are also visible through `quant_config`.

    :param quant_config: quantization config that receives the mapping
    :param model_class: model class that may define `packed_modules_mapping`
    """
    fused_modules = getattr(model_class, "packed_modules_mapping", None)

    # Nothing to share: leave quant_config untouched and warn the user
    if fused_modules is None:
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`, "
            "this may lead to incorrect mapping of quantized or ignored "
            "modules", model_class.__name__)
        return

    # assign the same object (by reference) rather than a copy
    quant_config.packed_modules_mapping = fused_modules
...@@ -265,12 +265,14 @@ class GLMAttention(nn.Module): ...@@ -265,12 +265,14 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias, bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense",
) )
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
...@@ -327,6 +329,7 @@ class GLMMLP(nn.Module): ...@@ -327,6 +329,7 @@ class GLMMLP(nn.Module):
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -338,6 +341,7 @@ class GLMMLP(nn.Module): ...@@ -338,6 +341,7 @@ class GLMMLP(nn.Module):
[config.ffn_hidden_size] * 2, [config.ffn_hidden_size] * 2,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
) )
self.activation_func = SiluAndMul() self.activation_func = SiluAndMul()
...@@ -348,6 +352,7 @@ class GLMMLP(nn.Module): ...@@ -348,6 +352,7 @@ class GLMMLP(nn.Module):
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
) )
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -396,7 +401,7 @@ class GLMBlock(nn.Module): ...@@ -396,7 +401,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon)
# MLP # MLP
self.mlp = GLMMLP(config, quant_config) self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward( def forward(
self, self,
...@@ -507,7 +512,8 @@ class ChatGLMModel(nn.Module): ...@@ -507,7 +512,8 @@ class ChatGLMModel(nn.Module):
self.embedding = VocabParallelEmbedding(config.padded_vocab_size, self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.embedding")
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
...@@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal): SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not # Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty. # initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {} packed_modules_mapping = {}
supported_lora_modules = [] supported_lora_modules = []
embedding_modules = {} embedding_modules = {}
...@@ -777,9 +784,18 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -777,9 +784,18 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
# Initialize VL # Initialize VL
if hasattr(config, "vision_config"): if hasattr(config, "vision_config"): # noqa: SIM108
return ChatGLMV(vllm_config=vllm_config, prefix=prefix) instance_cls = ChatGLMV
# Initialize LLM # Initialize LLM
else: else:
return ChatGLM(vllm_config=vllm_config, prefix=prefix) instance_cls = ChatGLM
\ No newline at end of file
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
...@@ -74,11 +74,13 @@ class Attention(nn.Module): ...@@ -74,11 +74,13 @@ class Attention(nn.Module):
self.head_dim, self.head_dim,
config.num_heads, config.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense",
) )
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
...@@ -101,6 +103,7 @@ class MLP(nn.Module): ...@@ -101,6 +103,7 @@ class MLP(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -109,11 +112,13 @@ class MLP(nn.Module): ...@@ -109,11 +112,13 @@ class MLP(nn.Module):
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1",
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc2",
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -137,7 +142,9 @@ class TransformerLayer(nn.Module): ...@@ -137,7 +142,9 @@ class TransformerLayer(nn.Module):
self.attention = Attention(config, self.attention = Attention(config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attention") prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config) self.mlp = MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.post_attention_layernorm = LayerNorm(config.hidden_size, self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -164,7 +171,7 @@ class Transformer(nn.Module): ...@@ -164,7 +171,7 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
TransformerLayer(config, TransformerLayer(config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}") prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
...@@ -181,6 +188,7 @@ class GLU(nn.Module): ...@@ -181,6 +188,7 @@ class GLU(nn.Module):
config, config,
in_features, in_features,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
): ):
""" """
The original implementation is the same as: The original implementation is the same as:
...@@ -222,7 +230,8 @@ class GLU(nn.Module): ...@@ -222,7 +230,8 @@ class GLU(nn.Module):
self.linear_proj = ReplicatedLinear(in_features, self.linear_proj = ReplicatedLinear(in_features,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.norm1 = nn.LayerNorm(config.hidden_size) self.norm1 = nn.LayerNorm(config.hidden_size)
self.act1 = nn.GELU() self.act1 = nn.GELU()
self.act2 = SiluAndMul() self.act2 = SiluAndMul()
...@@ -230,12 +239,15 @@ class GLU(nn.Module): ...@@ -230,12 +239,15 @@ class GLU(nn.Module):
self.merged_proj = MergedColumnParallelLinear( self.merged_proj = MergedColumnParallelLinear(
config.hidden_size, [config.ffn_hidden_size] * 2, config.hidden_size, [config.ffn_hidden_size] * 2,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.merged_proj")
self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size, self.dense_4h_to_h = RowParallelLinear(
config.hidden_size, config.ffn_hidden_size,
bias=False, config.hidden_size,
quant_config=quant_config) bias=False,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h")
def forward(self, x): def forward(self, x):
x, _ = self.linear_proj(x) x, _ = self.linear_proj(x)
...@@ -262,7 +274,8 @@ class EVA2CLIPModel(nn.Module): ...@@ -262,7 +274,8 @@ class EVA2CLIPModel(nn.Module):
prefix=f"{prefix}.transformer") prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config, self.linear_proj = GLU(config,
in_features=config.hidden_size, in_features=config.hidden_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
out_channels=config.hidden_size, out_channels=config.hidden_size,
kernel_size=2, kernel_size=2,
......
...@@ -1473,6 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): ...@@ -1473,6 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
""" """
# Ensure that the LoRA support check passes when the class is not # Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty. # initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {} packed_modules_mapping = {}
supported_lora_modules = [] supported_lora_modules = []
embedding_modules = {} embedding_modules = {}
...@@ -1489,8 +1490,15 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): ...@@ -1489,8 +1490,15 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
version = str(config.version).split(".") version = str(config.version).split(".")
version = tuple([int(x) for x in version]) version = tuple([int(x) for x in version])
# Dispatch class based on version # Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version) instance_cls = _SUPPORT_VERSION.get(version)
if instance_class is None: if instance_cls is None:
raise ValueError( raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(vllm_config=vllm_config, prefix=prefix)
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
...@@ -1135,6 +1135,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA): ...@@ -1135,6 +1135,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
""" """
# Ensure that the LoRA support check passes when the class is not # Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty. # initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {} packed_modules_mapping = {}
supported_lora_modules = [] supported_lora_modules = []
embedding_modules = {} embedding_modules = {}
...@@ -1146,9 +1147,18 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA): ...@@ -1146,9 +1147,18 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
prefix: str = "", prefix: str = "",
) -> QWenBaseModel: ) -> QWenBaseModel:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
# Initialize VL # Initialize VL
if hasattr(config, "visual"): if hasattr(config, "visual"): # noqa: SIM108
return QWenVL(vllm_config=vllm_config, prefix=prefix) instance_cls = QWenVL
# Initialize LLM # Initialize LLM
else: else:
return QWenLLM(vllm_config=vllm_config, prefix=prefix) instance_cls = QWenLLM
# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment