Commit 6c0a30bc authored by zhuwenwen's avatar zhuwenwen
Browse files

update flash-attn interface of apply_rotary_emb

parent b0dfa004
...@@ -40,6 +40,8 @@ from vllm.platforms import current_platform ...@@ -40,6 +40,8 @@ from vllm.platforms import current_platform
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm():
from flash_attn.layers.rotary import apply_rotary_emb
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
......
...@@ -9,7 +9,9 @@ from vllm.platforms import current_platform ...@@ -9,7 +9,9 @@ from vllm.platforms import current_platform
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm():
from flash_attn.layers.rotary import apply_rotary_emb
# common functions # common functions
def rotate_neox(x: torch.Tensor) -> torch.Tensor: def rotate_neox(x: torch.Tensor) -> torch.Tensor:
......
...@@ -115,6 +115,8 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, ...@@ -115,6 +115,8 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor,
apply_rotary_emb = apply_rotary_emb_torch apply_rotary_emb = apply_rotary_emb_torch
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm():
from flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t) output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output return output
......
...@@ -244,6 +244,8 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, ...@@ -244,6 +244,8 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor,
apply_rotary_emb = apply_rotary_emb_torch apply_rotary_emb = apply_rotary_emb_torch
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm():
from flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t) output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment