"vscode:/vscode.git/clone" did not exist on "2c2f955647539db6515128871b75325e0f79c2ea"
Unverified Commit 9a0d0b75, authored by Vincent Zhong and committed by GitHub

[Performance] Improve Qwen RMSNorm by replacing with native RMSNorm op (#9709)

parent ba861293
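For context on the change itself: Hugging Face's Qwen2RMSNorm runs the normalization as separate PyTorch ops, while sglang.srt.layers.layernorm.RMSNorm dispatches to a fused kernel that works on a 2D [num_tokens, hidden] view and can optionally fuse the residual add. The sketch below is illustrative only and not part of the diff; it shows the reference RMSNorm formula and the flatten-to-2D call pattern the new code uses.

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference formulation (assumption, mirrors Qwen2RMSNorm): compute in float32,
    # scale by the reciprocal RMS over the last dimension, then apply the learned weight.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(orig_dtype)

# The fused op is assumed to want a 2D [num_tokens, hidden] input, which is why the
# vision block below flattens [S, B, H] -> [S * B, H] before calling norm1/norm2.
S, B, H = 4, 2, 8
x = torch.randn(S, B, H)
w = torch.ones(H)
out = rmsnorm_reference(x.reshape(-1, H), w).reshape(S, B, H)
assert torch.allclose(out, rmsnorm_reference(x, w), atol=1e-6)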
@@ -31,7 +31,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from transformers.activations import ACT2FN
-from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
@@ -43,6 +42,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -122,8 +122,8 @@ class Qwen2_5_VisionBlock(nn.Module):
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
-        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=1e-6)
+        self.norm2 = RMSNorm(dim, eps=1e-6)
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -174,18 +174,29 @@
         cu_seqlens: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
-        hidden_states = self.norm1(x)
-        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
             position_embeddings=position_embeddings,
         )
-        attn = rearrange(attn, "b s ... -> s b ...")
-        x = x + attn
-        norm2 = self.norm2(x)
-        mlp = self.mlp(norm2)
-        x = x + mlp
+        attn = rearrange(attn, "b s h -> s b h")
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
         return x
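The norm2 call above relies on the RMSNorm layer accepting a residual argument and returning a pair. A minimal plain-PyTorch sketch of that assumed contract (the real layer dispatches to a fused kernel; the function name here is hypothetical):

import torch

def rmsnorm_with_residual(x2d, residual2d, weight, eps=1e-6):
    # Assumed contract: add the residual first, then RMS-normalize the sum,
    # and return (normalized, residual_after_add) so the caller keeps both.
    residual_after_add = x2d + residual2d
    variance = residual_after_add.pow(2).mean(dim=-1, keepdim=True)
    normalized = residual_after_add * torch.rsqrt(variance + eps)
    return weight * normalized, residual_after_add

# Mirrors the block above: x2d is the flattened residual stream, attn2d the attention output.
S, B, H = 3, 2, 16
x2d, attn2d, w = torch.randn(S * B, H), torch.randn(S * B, H), torch.ones(H)
x_norm_2d, x_after_add_2d = rmsnorm_with_residual(x2d, attn2d, w)
# x_norm_2d feeds the MLP; x_after_add_2d carries the residual for the final add.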
@@ -201,7 +212,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
-        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+        self.ln_q = RMSNorm(context_dim, eps=1e-6)
         self.mlp = nn.ModuleList(
             [
                 ColumnParallelLinear(
@@ -223,11 +234,13 @@
         )
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.ln_q(x)
-        x = x.view(-1, self.hidden_size)
+        # x expected shape: [S, B, context_dim]
+        S, B, D = x.shape
+        x2d = x.reshape(-1, D)
+        x2d = self.ln_q(x2d)  # RMSNorm expects 2D
+        x2d = x2d.view(-1, self.hidden_size)  # group into spatial_merge_unit
         mlp_fc1, mlp_act, mlp_fc2 = self.mlp
-        x_parallel, _ = mlp_fc1(x)
+        x_parallel, _ = mlp_fc1(x2d)
         x_parallel = mlp_act(x_parallel)
         out, _ = mlp_fc2(x_parallel)
         return out
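A small shape walk-through (illustrative numbers, not taken from the diff or the config) of the merger's reshapes: the norm runs at context_dim on the flattened tokens, and only afterwards are spatial_merge_size**2 neighbouring tokens regrouped into hidden_size for the MLP.

import torch

context_dim, spatial_merge_size = 1280, 2          # assumed, Qwen2.5-VL-style values
hidden_size = context_dim * spatial_merge_size**2  # 5120

S, B = 16, 1                                       # S must be divisible by spatial_merge_size**2
x = torch.randn(S, B, context_dim)

x2d = x.reshape(-1, context_dim)                   # [16, 1280] -> fed to the 2D RMSNorm
merged = x2d.view(-1, hidden_size)                 # [4, 5120]  -> fed to the parallel MLP
print(x2d.shape, merged.shape)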
@@ -394,6 +407,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        # Move window_index to the same device as x before using it to index x
+        window_index = window_index.to(device=x.device)
+        # Ensure rotary_pos_emb is on the same device/dtype as x
+        rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
         seq_len, _ = x.size()
         x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -406,12 +425,19 @@
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
         position_embeddings = (emb.cos(), emb.sin())
+        # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
+        position_embeddings = (
+            position_embeddings[0].to(x.device, x.dtype),
+            position_embeddings[1].to(x.device, x.dtype),
+        )
-        # compute cu_seqlens
+        # compute cu_seqlens - move cu_seqlens to GPU and make it int32
         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=grid_thw.device),
-                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+                torch.tensor([0], device=x.device, dtype=torch.int32),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+                .cumsum(dim=0)
+                .to(device=x.device, dtype=torch.int32),
             ]
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
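The cu_seqlens change in this hunk builds the cumulative per-image token counts directly on the activation device as int32, instead of on grid_thw's device, so the downstream attention kernels do not need an extra transfer. A standalone sketch with hypothetical grid_thw values, showing only the changed construction:

import torch

grid_thw = torch.tensor([[1, 4, 4], [1, 2, 6]])  # two images: 16 and 12 patches (t * h * w)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cu_seqlens = torch.cat(
    [
        torch.tensor([0], device=device, dtype=torch.int32),
        (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
        .cumsum(dim=0)
        .to(device=device, dtype=torch.int32),
    ]
)
print(cu_seqlens)  # tensor([ 0, 16, 28], dtype=torch.int32) on the chosen device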