Unverified Commit a3a3ee4e authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc] Merge bitsandbytes_stacked_params_mapping and packed_modules_mapping (#11924)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 87054a57
...@@ -39,7 +39,8 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -39,7 +39,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator) serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture, from vllm.model_executor.model_loader.utils import (ParamMapping,
get_model_architecture,
set_default_torch_dtype) set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf, download_safetensors_index_file_from_hf, download_weights_from_hf,
...@@ -983,21 +984,11 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -983,21 +984,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _get_bnb_target_modules(self, model: nn.Module) -> None: def _get_bnb_target_modules(self, model: nn.Module) -> None:
# TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
# packed_modules_mapping.
inverse_stacked_mapping: Dict[str, List[str]] = {}
for orig, (
packed,
idx,
) in model.bitsandbytes_stacked_params_mapping.items():
if packed not in inverse_stacked_mapping:
inverse_stacked_mapping[packed] = []
inverse_stacked_mapping[packed].insert(idx, orig)
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, (LinearBase, )): if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1] last_name = name.split(".")[-1]
if sub_modules := inverse_stacked_mapping.get(last_name, []): if sub_modules := self.modules_mapping.packed_mapping.get(
last_name, []):
# Map vllm's names to transformers's names. # Map vllm's names to transformers's names.
for sub_name in sub_modules: for sub_name in sub_modules:
self.target_modules.append( self.target_modules.append(
...@@ -1018,15 +1009,19 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -1018,15 +1009,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
"The required method 'load_weights' is not defined in class" "The required method 'load_weights' is not defined in class"
f" {type(model).__name__}.") f" {type(model).__name__}.")
if not hasattr(model, "bitsandbytes_stacked_params_mapping"): if not hasattr(model, "packed_modules_mapping"):
raise AttributeError( raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes " f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.") "quantization yet. No 'packed_modules_mapping' found.")
self.modules_mapping = ParamMapping(
copy.deepcopy(model.packed_modules_mapping))
# For some models like Molmo, we need to use hf_to_vllm_mapper # For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights. # to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
# Modules whose weights might have fused on disk # Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP # we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: Dict[str, List[int]] = {} self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
...@@ -1109,7 +1104,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -1109,7 +1104,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for shard_name, ( for shard_name, (
weight_name, weight_name,
index, index,
) in model.bitsandbytes_stacked_params_mapping.items(): ) in self.modules_mapping.inverse_packed_mapping.items():
shard_pos = quant_param_name.find(shard_name) shard_pos = quant_param_name.find(shard_name)
# Some models, such as MiniCPM V2.5/2.6, contain both # Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
......
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib import contextlib
from typing import Tuple, Type from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Type
import torch import torch
from torch import nn from torch import nn
...@@ -49,3 +50,26 @@ def get_model_architecture( ...@@ -49,3 +50,26 @@ def get_model_architecture(
def get_architecture_class_name(model_config: ModelConfig) -> str: def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1] return get_model_architecture(model_config)[1]
@dataclass
class ParamMapping:
"""
A class to handle parameter mapping for model weight loading.
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
packed_mapping: Dict[str, List[str]]
inverse_packed_mapping: Dict[str, Tuple[str,
int]] = field(default_factory=dict)
def __post_init__(self):
for packed_name, sub_params in self.packed_mapping.items():
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
if len(sub_params) == 1 and sub_params[0] == packed_name:
continue
for index, param_name in enumerate(sub_params):
self.inverse_packed_mapping[param_name] = (
packed_name,
index,
)
...@@ -350,13 +350,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -350,13 +350,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__( def __init__(
self, self,
*, *,
......
...@@ -430,14 +430,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -430,14 +430,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"c_fc_0": ("gate_up_proj", 0),
"c_fc_1": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -409,9 +409,9 @@ class FalconModel(nn.Module): ...@@ -409,9 +409,9 @@ class FalconModel(nn.Module):
class FalconForCausalLM(nn.Module, SupportsPP): class FalconForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
# BitandBytes specific attributes "query_key_value": ["query_key_value"],
bitsandbytes_stacked_params_mapping = {} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -349,15 +349,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -349,15 +349,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"gate_up_proj", "gate_up_proj",
"down_proj", "down_proj",
] ]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Gemma does not apply LoRA to the embedding layer. # Gemma does not apply LoRA to the embedding layer.
embedding_modules = {} embedding_modules = {}
......
...@@ -399,16 +399,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -399,16 +399,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
......
...@@ -362,14 +362,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -362,14 +362,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -662,16 +662,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -662,16 +662,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
"down_proj", "down_proj",
] ]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
......
...@@ -478,16 +478,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -478,16 +478,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral # Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints # from consolidated.safetensors checkpoints
mistral_mapping = { mistral_mapping = {
......
...@@ -463,14 +463,10 @@ def init_vision_tower_for_llava( ...@@ -463,14 +463,10 @@ def init_vision_tower_for_llava(
info=_build_llava_or_pixtral_hf_info, info=_build_llava_or_pixtral_hf_info,
dummy_inputs=LlavaDummyInputsBuilder) dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = { packed_modules_mapping = {
# shard_name, weight_name, index "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"q_proj": ("qkv_proj", 0), "gate_up_proj": ["gate_proj", "up_proj"]
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
......
...@@ -534,16 +534,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -534,16 +534,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
...@@ -241,11 +241,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): ...@@ -241,11 +241,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
# `embedding_modules` and `embedding_padding_modules` # `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM # are inherited from MiniCPMForCausalLM
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
...@@ -761,16 +761,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -761,16 +761,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
"kv_proj", "kv_proj",
] ]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
...@@ -881,16 +871,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -881,16 +871,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
"kv_proj", "kv_proj",
] ]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
......
...@@ -1107,14 +1107,9 @@ class MllamaForCausalLM(nn.Module): ...@@ -1107,14 +1107,9 @@ class MllamaForCausalLM(nn.Module):
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama) @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# BitandBytes specific attributes packed_modules_mapping = {
bitsandbytes_stacked_params_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"],
# shard_name, weight_name, index "gate_up_proj": ["gate_proj", "up_proj"]
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -1193,12 +1193,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1193,12 +1193,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
"gate_proj": ("merged_linear", 0),
"up_proj": ("merged_linear", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
...@@ -395,12 +395,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -395,12 +395,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings", "lm_head": "output_embeddings",
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -329,13 +329,9 @@ class OPTModel(nn.Module): ...@@ -329,13 +329,9 @@ class OPTModel(nn.Module):
class OPTForCausalLM(nn.Module, SupportsPP): class OPTForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
# BitandBytes specific attributes "qkv_proj": ["q_proj", "k_proj", "v_proj"],
bitsandbytes_stacked_params_mapping = { "gate_up_proj": ["gate_proj", "up_proj"]
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -279,14 +279,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -279,14 +279,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"fc2", "fc2",
] ]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
......
...@@ -14,7 +14,3 @@ class Phi3ForCausalLM(LlamaForCausalLM): ...@@ -14,7 +14,3 @@ class Phi3ForCausalLM(LlamaForCausalLM):
"gate_up_proj", "gate_up_proj",
], ],
} }
# BitandBytes specific attributes
# Initialize an empty dict when there is no stacked parameter mapping.
bitsandbytes_stacked_params_mapping = {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment