Unverified Commit 0240402c authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc]Add BNB quantization for MolmoForCausalLM (#11551)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 55509c21
...@@ -11,7 +11,8 @@ import os ...@@ -11,7 +11,8 @@ import os
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
Tuple, cast)
import gguf import gguf
import huggingface_hub import huggingface_hub
...@@ -706,6 +707,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -706,6 +707,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Store all module names (from transformers) that support # Store all module names (from transformers) that support
# BNB quantization. # BNB quantization.
self.target_modules: List[str] = [] self.target_modules: List[str] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
def _get_weight_files( def _get_weight_files(
self, self,
...@@ -763,9 +766,12 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -763,9 +766,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors: if use_safetensors:
return safetensors_weights_iterator(hf_weights_files) iterator = safetensors_weights_iterator(hf_weights_files)
else: else:
return pt_weights_iterator(hf_weights_files) iterator = pt_weights_iterator(hf_weights_files)
for name, param in iterator:
# mapping weight names from transformers to vllm.
yield self.weight_mapper(name), param
def _get_quantized_weights_iterator( def _get_quantized_weights_iterator(
self, self,
...@@ -782,12 +788,12 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -782,12 +788,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
try: try:
import bitsandbytes import bitsandbytes
if bitsandbytes.__version__ < "0.44.0": if bitsandbytes.__version__ < "0.45.0":
raise ImportError("bitsandbytes version is wrong. Please " raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.44.0.") "install bitsandbytes>=0.45.0.")
except ImportError as err: except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.44.0 via " raise ImportError("Please install bitsandbytes>=0.45.0 via "
"`pip install bitsandbytes>=0.44.0` to use " "`pip install bitsandbytes>=0.45.0` to use "
"bitsandbytes quantizer.") from err "bitsandbytes quantizer.") from err
hf_weights_files, use_safetensors = self._prepare_weights( hf_weights_files, use_safetensors = self._prepare_weights(
...@@ -991,7 +997,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -991,7 +997,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if isinstance(module, (LinearBase, )): if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1] last_name = name.split(".")[-1]
if sub_modules := inverse_stacked_mapping.get(last_name, []): if sub_modules := inverse_stacked_mapping.get(last_name, []):
# Map vllm's names to transformers' names. # Map vllm's names to transformers's names.
for sub_name in sub_modules: for sub_name in sub_modules:
self.target_modules.append( self.target_modules.append(
name.replace(last_name, sub_name)) name.replace(last_name, sub_name))
...@@ -1013,6 +1019,10 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -1013,6 +1019,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f"Model {type(model).__name__} does not support BitsAndBytes " f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.") "quantization yet.")
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
# Modules whose weights might have fused on disk # Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP # we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: Dict[str, List[int]] = {} self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
......
...@@ -461,30 +461,71 @@ class MolmoAttention(nn.Module): ...@@ -461,30 +461,71 @@ class MolmoAttention(nn.Module):
return output return output
class MolmoMLP(nn.Module): class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
# Note that the order is reversed compared to
# SiluAndMul.
return x * F.silu(gate)
class LanuageModelMLP(nn.Module):
"""Molmo's LLM mlp.""" """Molmo's LLM mlp."""
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
input_dim: Optional[int] = None, input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None:
proj_name: str = "gate_up_proj") -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2 self.intermediate_size = config.intermediate_size // 2
# Molmo's LLM proj weights are already merged into the disk, while self.gate_up_proj = MergedColumnParallelLinear(
# image_projector proj is separate. If the same proj_name were used, it input_dim or self.hidden_size,
# would create ambiguity and make it difficult to support BNB and LoRA. [self.intermediate_size] * 2,
self.proj_name = proj_name bias=False,
setattr( quant_config=quant_config,
self, proj_name, )
MergedColumnParallelLinear( # Activation function.
input_dim or self.hidden_size, self.act_fn = SwiGLU()
[self.intermediate_size] * 2, # Feed-forward output projection.
bias=False, self.down_proj = RowParallelLinear(
quant_config=quant_config, self.intermediate_size,
)) self.hidden_size,
bias=False,
quant_config=quant_config,
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class ImageProjectorMLP(nn.Module):
"""Molmo's image_projector mlp."""
def __init__(
self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2
self.merged_linear = MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)
# Activation function. # Activation function.
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
...@@ -500,7 +541,7 @@ class MolmoMLP(nn.Module): ...@@ -500,7 +541,7 @@ class MolmoMLP(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
gate_up, _ = getattr(self, self.proj_name)(x) gate_up, _ = self.merged_linear(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x return x
...@@ -523,9 +564,7 @@ class MolmoDecoderLayer(nn.Module): ...@@ -523,9 +564,7 @@ class MolmoDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn") prefix=f"{prefix}.self_attn")
# MLP block. # MLP block.
self.mlp = MolmoMLP(config, self.mlp = LanuageModelMLP(config, quant_config=quant_config)
quant_config=quant_config,
proj_name="gate_up_proj")
# LayerNorm # LayerNorm
assert config.layer_norm_type == "rms" assert config.layer_norm_type == "rms"
...@@ -617,11 +656,10 @@ class MolmoVisionBackbone(nn.Module): ...@@ -617,11 +656,10 @@ class MolmoVisionBackbone(nn.Module):
vision_config, vision_config,
nlayers=len(self.vit_layers), nlayers=len(self.vit_layers),
quant_config=quant_config) quant_config=quant_config)
self.image_projector = MolmoMLP( self.image_projector = ImageProjectorMLP(
config, config,
input_dim=vision_config.image_emb_dim, input_dim=vision_config.image_emb_dim,
quant_config=quant_config, quant_config=quant_config,
proj_name="merged_linear",
) )
image_dim = vision_config.image_emb_dim * len(self.vit_layers) image_dim = vision_config.image_emb_dim * len(self.vit_layers)
...@@ -842,10 +880,6 @@ class MolmoModel(nn.Module): ...@@ -842,10 +880,6 @@ class MolmoModel(nn.Module):
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "gate_up_proj" in name:
up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
...@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
}, },
) )
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
"gate_proj": ("merged_linear", 0),
"up_proj": ("merged_linear", 1),
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment