Unverified Commit bc24205b authored by ryang, committed by GitHub

Support BNB quantization for llama/mllama (#5038)


Co-authored-by: Yuhao Yang <yyh073@foxmail.com>
parent 3efc8e2d
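
This commit enables loading pre-quantized bitsandbytes (BNB) 4-bit checkpoints for Llama 3.2 text and vision (mllama) models: `quant_config` is threaded through the mllama vision tower, the multi-modal projector becomes a quantizable `ReplicatedLinear`, and `MllamaForConditionalGeneration` gains the class attributes that `BitsAndBytesModelLoader` consumes. A rough usage sketch, assuming `load_format="bitsandbytes"` is the switch that routes loading through `BitsAndBytesModelLoader` (the launch arguments are not part of this diff):

```python
# Hypothetical usage sketch, not taken from this commit: load one of the
# pre-quantized BNB 4-bit checkpoints added to the tests below with the
# offline engine. load_format="bitsandbytes" is assumed to select
# BitsAndBytesModelLoader.
import sglang as sgl

if __name__ == "__main__":
    engine = sgl.Engine(
        model_path="unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
        load_format="bitsandbytes",
    )
    out = engine.generate("The capital of France is", {"max_new_tokens": 8})
    print(out["text"])
    engine.shutdown()
```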
@@ -1074,7 +1074,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
+            if model_type == "mllama" and "vision_model" in quant_param_name:
+                # adapt to VisionAttention
+                quant_param_name = quant_param_name.replace(
+                    "self_attn.o_proj", "self_attn.proj"
+                )
             shard_index = 0
             for shard_name, (
                 weight_name,
...
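
In SGLang, the vision attention output projection is registered as `self_attn.proj` (see `VisionAttention`), while the Hugging Face mllama checkpoint and its BNB quantization states use `self_attn.o_proj`. `load_weights` in the mllama model already renames the weights themselves (see the hunk near the end of this diff); the loader change above applies the same rename to the quant-state keys so they match the renamed parameters. A minimal, self-contained illustration of the mapping (the helper function and the example key are made up for illustration; the string replacement is copied from the hunk):

```python
# Illustrative helper only; the real logic lives inline in
# BitsAndBytesModelLoader's quant-state loop shown above.
def remap_mllama_vision_quant_key(model_type: str, quant_param_name: str) -> str:
    if model_type == "mllama" and "vision_model" in quant_param_name:
        # adapt to VisionAttention, whose output projection is named "proj"
        quant_param_name = quant_param_name.replace(
            "self_attn.o_proj", "self_attn.proj"
        )
    return quant_param_name


print(
    remap_mllama_vision_quant_key(
        "mllama", "vision_model.transformer.layers.0.self_attn.o_proj.weight"
    )
)
# vision_model.transformer.layers.0.self_attn.proj.weight
```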
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=None,
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):
 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,6 +333,7 @@
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -765,6 +780,27 @@ class MllamaForCausalLM(nn.Module):
 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
@@ -772,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -782,17 +819,21 @@
         self.image_size = config.vision_config.image_size
         self.vision_model = MllamaVisionModel(
-            config.vision_config, prefix=add_prefix("vision_model", prefix)
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
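
The new class attributes mirror what `BitsAndBytesModelLoader` already expects from other BNB-enabled models: `default_bitsandbytes_target_modules` lists the linear projections whose 4-bit quant states should be loaded, `column_parallel_weights_modules` marks projections whose weights are split along `dim=-1` under tensor parallelism, and `bitsandbytes_stacked_params_mapping` tells the loader that per-projection checkpoint weights (`q_proj`/`k_proj`/`v_proj`, `gate_proj`/`up_proj`) land in fused parameters (`qkv_proj`, `gate_up_proj`) at a given shard index. A simplified, self-contained illustration of how that mapping resolves a checkpoint key (the `resolve_stacked` helper and the example name are illustrative; the real loader also handles quant states and TP sharding):

```python
# Mapping copied from the diff; the resolution logic is simplified for
# illustration.
bitsandbytes_stacked_params_mapping = {
    # shard_name: (fused weight_name, shard index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def resolve_stacked(param_name: str):
    """Return (fused parameter name, shard index) for a checkpoint key."""
    mapping = bitsandbytes_stacked_params_mapping
    for shard_name, (weight_name, shard_index) in mapping.items():
        if shard_name in param_name:
            return param_name.replace(shard_name, weight_name), shard_index
    # Not a stacked projection (e.g. o_proj, down_proj): keep the name as-is.
    return param_name, 0


print(resolve_stacked("language_model.model.layers.0.self_attn.q_proj.weight"))
# ('language_model.model.layers.0.self_attn.qkv_proj.weight', 0)
```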
@@ -959,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
             cross_attention_states = self.vision_model(
                 batched_images, batched_ar_ids, batched_ar_mask
             )
-            cross_attention_states = self.multi_modal_projector(cross_attention_states)
+            cross_attention_states, _ = self.multi_modal_projector(
+                cross_attention_states
+            )
             bs, _, _, _, image_token_dim = cross_attention_states.shape
             cross_attention_states = cross_attention_states.view(
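
`ReplicatedLinear` keeps a full copy of the projector weight on every TP rank (matching the old `nn.Linear` semantics) while letting the BNB path quantize and load it like the other linear layers. Unlike `nn.Linear`, SGLang's linear layers return a tuple from `forward` (the output plus an optional bias term), which is why the call site now unpacks and discards the second element. A tiny stand-in to show the call-site change (`FakeReplicatedLinear` is illustrative, not the real layer):

```python
# Minimal stand-in for why the call site unpacks a tuple: the real
# ReplicatedLinear.forward returns (output, output_bias) rather than a tensor.
import torch
import torch.nn as nn


class FakeReplicatedLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor):
        # Bias is already applied here; the second element mimics the
        # optional separately-returned bias of the real layer.
        return self.linear(x), None


proj = FakeReplicatedLinear(16, 8)  # dimensions are arbitrary
states = torch.randn(2, 4, 16)
states, _ = proj(states)  # tuple unpack, matching the new call site
print(states.shape)  # torch.Size([2, 4, 8])
```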
@@ -1013,7 +1056,6 @@
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
...
""" """
Usage: Usage:
python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch python3 -m unittest test_bnb.TestVisionModel.test_vlm
python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_multi_images_chat_completion python3 -m unittest test_bnb.TestLanguageModel.test_mmlu
""" """
import base64 import base64
...@@ -31,10 +31,13 @@ from sglang.test.test_utils import ( ...@@ -31,10 +31,13 @@ from sglang.test.test_utils import (
VISION_MODELS = [ VISION_MODELS = [
("unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", "qwen2-vl"), ("unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
("unsloth/Qwen2-VL-7B-Instruct-bnb-4bit", "qwen2-vl"), ("unsloth/Qwen2-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
("unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", "llama_3_vision"),
("unsloth/Llama-3.2-11B-Vision-bnb-4bit", "llama_3_vision"),
] ]
LANGUAGE_MODELS = [ LANGUAGE_MODELS = [
"unsloth/Qwen2.5-7B-Instruct-bnb-4bit", "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
"unsloth/Qwen2-7B-Instruct-bnb-4bit", "unsloth/Qwen2-7B-Instruct-bnb-4bit",
"unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
] ]
# image # image
......
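
The test additions only register the new checkpoints in the existing parameterized lists; per the updated docstring, the suites are run with `python3 -m unittest test_bnb.TestVisionModel.test_vlm` and `python3 -m unittest test_bnb.TestLanguageModel.test_mmlu`. For a quick manual check of one of the new vision entries, something like the following should work; this is a sketch assuming `--load-format bitsandbytes` selects the BNB loader and that the paired chat template name from the list is passed via `--chat-template` (the test's actual launch arguments are not shown in this hunk):

```python
# Hypothetical manual launch of one of the newly added BNB vision checkpoints.
import subprocess

MODEL, CHAT_TEMPLATE = (
    "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
    "llama_3_vision",
)

subprocess.run(
    [
        "python3", "-m", "sglang.launch_server",
        "--model-path", MODEL,
        "--chat-template", CHAT_TEMPLATE,
        "--load-format", "bitsandbytes",
    ],
    check=True,
)
```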