Unverified Commit a3a3ee4e authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc] Merge bitsandbytes_stacked_params_mapping and packed_modules_mapping (#11924)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 87054a57
......@@ -39,7 +39,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
from vllm.model_executor.model_loader.utils import (ParamMapping,
get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
......@@ -983,21 +984,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _get_bnb_target_modules(self, model: nn.Module) -> None:
# TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
# packed_modules_mapping.
inverse_stacked_mapping: Dict[str, List[str]] = {}
for orig, (
packed,
idx,
) in model.bitsandbytes_stacked_params_mapping.items():
if packed not in inverse_stacked_mapping:
inverse_stacked_mapping[packed] = []
inverse_stacked_mapping[packed].insert(idx, orig)
for name, module in model.named_modules():
if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1]
if sub_modules := inverse_stacked_mapping.get(last_name, []):
if sub_modules := self.modules_mapping.packed_mapping.get(
last_name, []):
# Map vllm's names to transformers's names.
for sub_name in sub_modules:
self.target_modules.append(
......@@ -1018,15 +1009,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}.")
if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
if not hasattr(model, "packed_modules_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.")
"quantization yet. No 'packed_modules_mapping' found.")
self.modules_mapping = ParamMapping(
copy.deepcopy(model.packed_modules_mapping))
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
......@@ -1109,7 +1104,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for shard_name, (
weight_name,
index,
) in model.bitsandbytes_stacked_params_mapping.items():
) in self.modules_mapping.inverse_packed_mapping.items():
shard_pos = quant_param_name.find(shard_name)
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
......
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Type
import torch
from torch import nn
......@@ -49,3 +50,26 @@ def get_model_architecture(
def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]
@dataclass
class ParamMapping:
"""
A class to handle parameter mapping for model weight loading.
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
packed_mapping: Dict[str, List[str]]
inverse_packed_mapping: Dict[str, Tuple[str,
int]] = field(default_factory=dict)
def __post_init__(self):
for packed_name, sub_params in self.packed_mapping.items():
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
if len(sub_params) == 1 and sub_params[0] == packed_name:
continue
for index, param_name in enumerate(sub_params):
self.inverse_packed_mapping[param_name] = (
packed_name,
index,
)
......@@ -350,13 +350,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
*,
......
......@@ -430,14 +430,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"c_fc_0": ("gate_up_proj", 0),
"c_fc_1": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -409,9 +409,9 @@ class FalconModel(nn.Module):
class FalconForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -349,15 +349,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"gate_up_proj",
"down_proj",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
......
......@@ -399,16 +399,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......
......@@ -362,14 +362,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -662,16 +662,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
"down_proj",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {}
embedding_padding_modules = []
......
......@@ -478,16 +478,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
......
......@@ -463,14 +463,10 @@ def init_vision_tower_for_llava(
info=_build_llava_or_pixtral_hf_info,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
......
......@@ -534,16 +534,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
......
......@@ -241,11 +241,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
# `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
......@@ -761,16 +761,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
"kv_proj",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {}
embedding_padding_modules = []
......@@ -881,16 +871,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
"kv_proj",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {}
embedding_padding_modules = []
......
......@@ -1107,14 +1107,9 @@ class MllamaForCausalLM(nn.Module):
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
......@@ -1193,12 +1193,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
embedding_modules = {}
embedding_padding_modules = []
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
"gate_proj": ("merged_linear", 0),
"up_proj": ("merged_linear", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
......
......@@ -395,12 +395,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -329,13 +329,9 @@ class OPTModel(nn.Module):
class OPTForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
......@@ -279,14 +279,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"fc2",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
embedding_modules = {}
embedding_padding_modules = []
......
......@@ -14,7 +14,3 @@ class Phi3ForCausalLM(LlamaForCausalLM):
"gate_up_proj",
],
}
# BitandBytes specific attributes
# Initialize an empty dict when there is no stacked parameter mapping.
bitsandbytes_stacked_params_mapping = {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment