"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "ac2f3f7fee93cf9cd97c0078e362feab7b6c8299"
Unverified Commit af8486de authored by Sanju C Sudhakaran's avatar Sanju C Sudhakaran Committed by GitHub
Browse files

[Hardware][Intel-Gaudi] Enable FusedSDPA support for Intel Gaudi (HPU)

parent 4c3aac51
...@@ -10,7 +10,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type ...@@ -10,7 +10,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
import vllm_hpu_extension.ops as ops import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionLayer,
...@@ -137,9 +138,17 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -137,9 +138,17 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true'] '0').lower() in ['1', 'true']
self.fused_scaled_dot_product_attention = None
if self.prefill_usefusedsdpa: if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \ assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!' 'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger().warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")
suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes() suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes: if head_size not in suppored_head_sizes:
...@@ -227,6 +236,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -227,6 +236,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
matmul_qk_op=self.matmul_qk, matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax, softmax_op=self.softmax,
matmul_av_op=self.matmul_av, matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention,
) )
output = out.reshape(batch_size, seq_len, hidden_size) output = out.reshape(batch_size, seq_len, hidden_size)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment