Unverified commit 0a56bcc0, authored by Jani Monoses and committed by GitHub
Browse files

[Bugfix][Hardware][CPU] Enable Gemma2 with SDPA on CPU backend (#11169)

parent 0920ab91
...@@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad, print_warning_once
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
...@@ -395,7 +395,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -395,7 +395,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
raise ValueError( raise ValueError(
"Torch SPDA does not support block-sparse attention.") "Torch SPDA does not support block-sparse attention.")
if logits_soft_cap is not None: if logits_soft_cap is not None:
raise ValueError("Torch SPDA does not support logits soft cap.") print_warning_once("Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
...@@ -619,7 +620,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -619,7 +620,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
value[None, :, start_kv:end_kv, :], value[None, :, start_kv:end_kv, :],
attn_mask=mask, attn_mask=mask,
dropout_p=0.0, dropout_p=0.0,
is_causal=causal_attn and not self.need_mask, is_causal=causal_attn and mask is None,
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv start_q, start_kv = end_q, end_kv
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment