Unverified Commit 3c3d7672 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[BugFix] Fix mla cpu - missing 3 required positional arguments (#17494)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent 13cf6b62
...@@ -177,7 +177,7 @@ class ipex_ops: ...@@ -177,7 +177,7 @@ class ipex_ops:
out: torch.Tensor, out: torch.Tensor,
seqlen_q: torch.Tensor, seqlen_q: torch.Tensor,
seqlen_k: torch.Tensor, seqlen_k: torch.Tensor,
alibi_slopes: torch.Tensor, alibi_slopes: Optional[torch.Tensor],
max_seqlen_q: int, max_seqlen_q: int,
max_seqlen_k: int, max_seqlen_k: int,
pdropout: float, pdropout: float,
...@@ -193,6 +193,8 @@ class ipex_ops: ...@@ -193,6 +193,8 @@ class ipex_ops:
if ipex.__version__.endswith("cpu"): if ipex.__version__.endswith("cpu"):
if logits_soft_cap != 0.0: if logits_soft_cap != 0.0:
raise ValueError("IPEX CPU does not support logits_soft_cap") raise ValueError("IPEX CPU does not support logits_soft_cap")
assert alibi_slopes is None
assert window_size_left < 0 and window_size_right < 0
ipex.llm.functional.varlen_attention(query.contiguous(), ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(), key.contiguous(),
value.contiguous(), out, value.contiguous(), out,
......
...@@ -273,6 +273,9 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]): ...@@ -273,6 +273,9 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
return_softmax=False, return_softmax=False,
gen_=None, gen_=None,
logits_soft_cap=0.0, logits_soft_cap=0.0,
window_size_left=-1,
window_size_right=-1,
alibi_slopes=None,
) )
# remove padding # remove padding
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment