Unverified Commit 035fd2bd authored by Wenlong Wang's avatar Wenlong Wang Committed by GitHub
Browse files

[Multi Modal][Performance] Fused Q,K's apply_rope in more models (#25005)


Signed-off-by: default avatarwwl2755 <wangwenlong2755@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 1cd885bd
...@@ -234,8 +234,9 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -234,8 +234,9 @@ class Ernie4_5_VisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v)) for x in (q, k, v))
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) qk_concat = torch.cat([q, k], dim=0)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend: if self.is_flash_attn_backend:
# from vllm_flash_attn.flash_attn_interface import ( # from vllm_flash_attn.flash_attn_interface import (
...@@ -261,8 +262,8 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -261,8 +262,8 @@ class Ernie4_5_VisionAttention(nn.Module):
causal=False) causal=False)
context_layer = rearrange(output, context_layer = rearrange(output,
"(b s) ... -> b s ...", "(b s) h d -> s b (h d)",
b=batch_size) b=batch_size).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
outputs = [] outputs = []
...@@ -281,6 +282,8 @@ class Ernie4_5_VisionAttention(nn.Module): ...@@ -281,6 +282,8 @@ class Ernie4_5_VisionAttention(nn.Module):
output_i = rearrange(output_i, "b h s d -> b s h d ") output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i) outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
......
...@@ -315,8 +315,10 @@ class Glm4vVisionAttention(nn.Module): ...@@ -315,8 +315,10 @@ class Glm4vVisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v)) for x in (q, k, v))
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) # [2 * b, s, heads, head_dim]
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import ( # from vllm_flash_attn.flash_attn_interface import (
...@@ -341,8 +343,8 @@ class Glm4vVisionAttention(nn.Module): ...@@ -341,8 +343,8 @@ class Glm4vVisionAttention(nn.Module):
) )
context_layer = rearrange(output, context_layer = rearrange(output,
"(b s) ... -> b s ...", "(b s) h d -> s b (h d)",
b=batch_size) b=batch_size).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
outputs = [] outputs = []
...@@ -361,6 +363,8 @@ class Glm4vVisionAttention(nn.Module): ...@@ -361,6 +363,8 @@ class Glm4vVisionAttention(nn.Module):
output_i = rearrange(output_i, "b h s d -> b s h d ") output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i) outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...@@ -371,7 +375,6 @@ class Glm4vVisionAttention(nn.Module): ...@@ -371,7 +375,6 @@ class Glm4vVisionAttention(nn.Module):
context_layer = xops.memory_efficient_attention_forward( context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None) q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer, context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous() "b s h d -> s b (h d)").contiguous()
......
...@@ -377,8 +377,10 @@ class Qwen2VisionAttention(nn.Module): ...@@ -377,8 +377,10 @@ class Qwen2VisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v)) for x in (q, k, v))
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) # [2 * b, s, heads, head_dim]
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend: if self.is_flash_attn_backend:
if self.attn_backend == _Backend.ROCM_AITER_FA: if self.attn_backend == _Backend.ROCM_AITER_FA:
...@@ -402,8 +404,8 @@ class Qwen2VisionAttention(nn.Module): ...@@ -402,8 +404,8 @@ class Qwen2VisionAttention(nn.Module):
causal=False) causal=False)
context_layer = rearrange(output, context_layer = rearrange(output,
"(b s) ... -> b s ...", "(b s) h d -> s b (h d)",
b=batch_size) b=batch_size).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
outputs = [] outputs = []
...@@ -422,6 +424,8 @@ class Qwen2VisionAttention(nn.Module): ...@@ -422,6 +424,8 @@ class Qwen2VisionAttention(nn.Module):
output_i = rearrange(output_i, "b h s d -> b s h d ") output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i) outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment