Unverified Commit 3cfa63ad authored by Yan Ma's avatar Yan Ma Committed by GitHub
Browse files

[XPU]fix Kimi-VL-A3B-thinking on xpu (#29309)


Signed-off-by: default avatarYan Ma <yan.ma@intel.com>
parent 4d6afcad
......@@ -56,10 +56,13 @@ from transformers.utils import is_flash_attn_2_available
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
elif current_platform.is_xpu():
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
......@@ -106,10 +109,10 @@ def multihead_attention(
q,
k,
v,
q_cu_seqlens,
k_cu_seqlens,
max_seqlen_q,
max_seqlen_k,
cu_seqlens_q=q_cu_seqlens,
cu_seqlens_k=k_cu_seqlens,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
causal=False,
)
attn_out = attn_out.flatten(start_dim=-2)
......@@ -291,7 +294,12 @@ class Rope2DPosEmb(nn.Module):
"""
def __init__(
self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
self,
dim: int,
max_height: int,
max_width: int,
theta_base=10000,
device=current_platform.device_type,
):
super().__init__()
self.dim = dim
......@@ -437,7 +445,7 @@ class MoonVitEncoderLayer(nn.Module):
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
self.attn_implementation = attn_implementation
# use fa2 in vllm by default
if is_flash_attn_2_available():
if is_flash_attn_2_available() or current_platform.is_xpu():
self.attn_implementation = "flash_attention_2"
self.norm0 = nn.LayerNorm(hidden_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment