Unverified Commit a391f73a authored by Kevin Xiang Li's avatar Kevin Xiang Li Committed by GitHub
Browse files

Fuse gate_proj and up_proj in Qwen 2.5 VL's vision MLP (#9661)


Signed-off-by: default avatarXinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: default avatarXiang (Kevin) Li <lik@nvidia.com>
Co-authored-by: default avatarXinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarXinyuan Tong <xinyuantong.cs@gmail.com>
parent 25c73959
......@@ -43,7 +43,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
......@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
class Qwen2_5_VLMLP(nn.Module):
def __init__(
self,
in_features: int,
......@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
prefix: str = "",
):
super().__init__()
self.gate_proj = ColumnParallelLinear(
in_features,
hidden_features,
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_proj", prefix),
)
self.up_proj = ColumnParallelLinear(
in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("up_proj", prefix),
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
hidden_features,
......@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
self.act = ACT2FN[hidden_act]
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel_gate, _ = self.gate_proj(x)
x_parallel_gate = self.act(x_parallel_gate)
x_parallel_up, _ = self.up_proj(x)
x_parallel = x_parallel_gate * x_parallel_up
x, _ = self.down_proj(x_parallel)
return x
gate_up, _ = self.gate_up_proj(x)
gate, up = gate_up.chunk(2, dim=-1)
x = self.act(gate) * up
x_down, _ = self.down_proj(x)
return x_down
class Qwen2_5_VisionBlock(nn.Module):
......@@ -353,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
@property
def device(self) -> torch.device:
return self.blocks[0].mlp.gate_proj.weight.device
return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
......@@ -468,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
class Qwen2_5_VLForConditionalGeneration(nn.Module):
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".gate_up_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
......@@ -617,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "visual" in name:
if (
"visual" in name
and "up_proj" not in name
and "gate_proj" not in name
):
continue
name = name.replace(weight_name, param_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment