"vscode:/vscode.git/clone" did not exist on "ab0c0ec68f73680cb224526b2d29f2e29a936138"
Unverified commit a391f73a authored by Kevin Xiang Li, committed by GitHub

Fuse gate_proj and up_proj in Qwen 2.5 VL's vision MLP (#9661)


Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xiang (Kevin) Li <lik@nvidia.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent 25c73959
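
The change collapses the two column-parallel projections in the vision MLP's SwiGLU block into a single merged projection, so each forward pass runs one GEMM for gate and up instead of two. A minimal sketch of the fused pattern with plain torch.nn.Linear (illustrative class name only; no tensor parallelism or quantization):

import torch
import torch.nn as nn

class FusedGateUpMLP(nn.Module):
    # Sketch of the fused pattern: one linear layer produces [gate | up] in a
    # single matmul, which is then split and combined as act(gate) * up.
    def __init__(self, in_features: int, hidden_features: int, bias: bool = True):
        super().__init__()
        self.gate_up_proj = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.down_proj = nn.Linear(hidden_features, in_features, bias=bias)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)       # one GEMM instead of two
        gate, up = gate_up.chunk(2, dim=-1)  # split back into gate and up halves
        return self.down_proj(self.act(gate) * up)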
@@ -43,7 +43,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
 class Qwen2_5_VLMLP(nn.Module):
     def __init__(
         self,
         in_features: int,
@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=add_prefix("gate_proj", prefix),
-        )
-        self.up_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=add_prefix("up_proj", prefix),
-        )
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
         self.down_proj = RowParallelLinear(
             hidden_features,
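
For context, MergedColumnParallelLinear stacks the gate and up weights along the output dimension and shards each half per tensor-parallel rank, so a rank's local output can still be split cleanly into its gate and up pieces. A toy sketch of that weight layout (made-up sizes, plain tensors, no real TP groups):

import torch

in_features, hidden, tp_size = 32, 64, 2
w_gate = torch.randn(hidden, in_features)
w_up = torch.randn(hidden, in_features)

per_rank = hidden // tp_size
rank_shards = [
    # Each rank keeps its slice of the gate weight followed by its slice of
    # the up weight, concatenated along the output dimension.
    torch.cat([w_gate[r * per_rank:(r + 1) * per_rank],
               w_up[r * per_rank:(r + 1) * per_rank]], dim=0)
    for r in range(tp_size)
]
# Each shard is (2 * hidden / tp_size, in_features): first half gate, second half up.
assert rank_shards[0].shape == (2 * per_rank, in_features)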
@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
         self.act = ACT2FN[hidden_act]

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_parallel_gate, _ = self.gate_proj(x)
-        x_parallel_gate = self.act(x_parallel_gate)
-        x_parallel_up, _ = self.up_proj(x)
-        x_parallel = x_parallel_gate * x_parallel_up
-        x, _ = self.down_proj(x_parallel)
-        return x
+        gate_up, _ = self.gate_up_proj(x)
+        gate, up = gate_up.chunk(2, dim=-1)
+        x = self.act(gate) * up
+        x_down, _ = self.down_proj(x)
+        return x_down

 class Qwen2_5_VisionBlock(nn.Module):
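
The new forward is numerically equivalent to the old one when the fused weight is the row-wise concatenation of the original gate and up weights; chunk(2, dim=-1) recovers the gate half first, matching the output_sizes order above. A quick check with plain tensors (toy sizes, SiLU standing in for ACT2FN[hidden_act]):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, h = 64, 256
w_gate, w_up, w_down = torch.randn(h, d), torch.randn(h, d), torch.randn(d, h)
x = torch.randn(2, d)

# Original, unfused path: two projections, then act(gate) * up.
ref = (F.silu(x @ w_gate.t()) * (x @ w_up.t())) @ w_down.t()

# Fused path: one projection against the concatenated weight, then chunk.
gate_up = x @ torch.cat([w_gate, w_up], dim=0).t()
gate, up = gate_up.chunk(2, dim=-1)
out = (F.silu(gate) * up) @ w_down.t()

assert torch.allclose(ref, out, atol=1e-4)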
@@ -353,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
     @property
     def device(self) -> torch.device:
-        return self.blocks[0].mlp.gate_proj.weight.device
+        return self.patch_embed.proj.weight.device

     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
@@ -468,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
-        ".gate_proj.",
+        ".gate_up_proj.",
         ".down_proj.",
-        ".up_proj.",
         ".q_proj.",
         ".k_proj.",
         ".v_proj.",
@@ -617,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         for param_name, weight_name, shard_id in stacked_params_mapping:
             if weight_name not in name:
                 continue
-            if "visual" in name:
+            if (
+                "visual" in name
+                and "up_proj" not in name
+                and "gate_proj" not in name
+            ):
                 continue
             name = name.replace(weight_name, param_name)
...
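
The weight-loading hunk is needed because the checkpoint still stores separate gate_proj and up_proj tensors for the visual blocks; after the fusion they must be copied into the matching halves of the merged parameter instead of being skipped as visual weights. A simplified sketch of that idea (hypothetical helper, plain tensors rather than SGLang's weight_loader machinery):

import torch

hidden, in_features = 64, 32
# Fused parameter created by the model: first half is gate, second half is up.
gate_up_weight = torch.empty(2 * hidden, in_features)

# In the spirit of stacked_params_mapping: (fused name, checkpoint name, shard id).
stacked_params_mapping = [
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def load_shard(fused: torch.Tensor, ckpt_tensor: torch.Tensor, shard_id: int) -> None:
    # Copy the checkpoint tensor into its half of the fused weight.
    out = ckpt_tensor.shape[0]
    fused[shard_id * out:(shard_id + 1) * out].copy_(ckpt_tensor)

checkpoint = {
    "visual.blocks.0.mlp.gate_proj.weight": torch.randn(hidden, in_features),
    "visual.blocks.0.mlp.up_proj.weight": torch.randn(hidden, in_features),
}
for name, tensor in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            load_shard(gate_up_weight, tensor, shard_id)
            break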