Unverified Commit c3e9d506 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Misc] Use `apply_rotary_emb` from vllm_flash_attn for Qwen2-VL vision RoPE (#17726)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent 822de7fb
...@@ -297,13 +297,8 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -297,13 +297,8 @@ class Qwen2_5_VisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v)) for x in (q, k, v))
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
q = apply_rotary_pos_emb_vision(q, k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
rotary_pos_emb,
use_flash_attn=use_flash_attn)
k = apply_rotary_pos_emb_vision(k,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import ( # from vllm_flash_attn.flash_attn_interface import (
......
...@@ -64,7 +64,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -64,7 +64,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import ( from vllm.transformers_utils.processor import (
...@@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor, ...@@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor,
def apply_rotary_pos_emb_vision(t: torch.Tensor, def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
use_flash_attn=False) -> torch.Tensor:
t_ = t.float() t_ = t.float()
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
apply_rotary_emb = apply_rotary_emb_torch apply_rotary_emb = apply_rotary_emb_torch
if use_flash_attn: if current_platform.is_cuda():
from flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t) output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment