Commit 4ba4b755 authored by zhuwenwen
Browse files

update apply_rotary_emb for z100l&k100 (keye)

parent 62fe9a48
......@@ -55,6 +55,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils import SUPPORT_TC
logger = init_logger(__name__)
......@@ -331,10 +332,11 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if not current_platform.is_rocm():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
else:
from flash_attn.layers.rotary import apply_rotary_emb
if SUPPORT_TC:
if not current_platform.is_rocm():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
else:
from flash_attn.layers.rotary import apply_rotary_emb
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment