Commit 4d53d14c authored by zhuwenwen's avatar zhuwenwen
Browse files

update flash-attn interface to support keye

parent a5f106eb
...@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, ...@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, is_pp_missing_parameter, init_vllm_registered_model, is_pp_missing_parameter,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vit_attn_backend from .vision import get_vit_attn_backend
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -330,7 +331,10 @@ def apply_rotary_pos_emb_flashatt( ...@@ -330,7 +331,10 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb if not current_platform.is_rocm():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
else:
from flash_attn.layers.rotary import apply_rotary_emb
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment