Unverified commit 639402f5, authored by Xinyu Chen, committed by GitHub
Browse files

Support FP8 KVCache on XPU (#37731)


Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
parent 0f7be0f2
...@@ -35,6 +35,7 @@ steps: ...@@ -35,6 +35,7 @@ steps:
python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp && python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp &&
python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN && python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN &&
python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8 && python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8 &&
python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --kv-cache-dtype fp8 &&
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192 && python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192 &&
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 && python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 &&
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel' python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel'
......
...@@ -258,6 +258,9 @@ class xpu_ops: ...@@ -258,6 +258,9 @@ class xpu_ops:
# alibi_slopes = alibi_slopes, # alibi_slopes = alibi_slopes,
# softcap=softcap, # softcap=softcap,
return_softmax_lse=return_softmax_lse, return_softmax_lse=return_softmax_lse,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
) )
@staticmethod @staticmethod
......
...@@ -166,12 +166,18 @@ def is_fa_version_supported(fa_version: int) -> bool: ...@@ -166,12 +166,18 @@ def is_fa_version_supported(fa_version: int) -> bool:
def flash_attn_supports_fp8() -> bool: def flash_attn_supports_fp8() -> bool:
if current_platform.is_xpu():
return True
return ( return (
get_flash_attn_version() == 3 get_flash_attn_version() == 3
and current_platform.is_device_capability_family(90) and current_platform.is_device_capability_family(90)
) )
def flash_attn_supports_quant_query_input() -> bool:
return not current_platform.is_xpu()
def flash_attn_supports_sinks() -> bool: def flash_attn_supports_sinks() -> bool:
if current_platform.is_xpu(): if current_platform.is_xpu():
return True return True
......
...@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import ( ...@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import (
) )
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_fp8, flash_attn_supports_fp8,
flash_attn_supports_quant_query_input,
get_flash_attn_version, get_flash_attn_version,
is_fa_version_supported, is_fa_version_supported,
is_flash_attn_varlen_func_available, is_flash_attn_varlen_func_available,
...@@ -656,7 +657,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -656,7 +657,7 @@ class FlashAttentionImpl(AttentionImpl):
"heads in the layer" "heads in the layer"
) )
self.supports_quant_query_input = True self.supports_quant_query_input = flash_attn_supports_quant_query_input()
vllm_config = get_current_vllm_config_or_none() vllm_config = get_current_vllm_config_or_none()
dcp_a2a = ( dcp_a2a = (
...@@ -757,7 +758,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -757,7 +758,11 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
q_descale = layer._q_scale.expand(descale_shape) q_descale = (
layer._q_scale.expand(descale_shape)
if self.supports_quant_query_input
else None
)
k_descale = layer._k_scale.expand(descale_shape) k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape) v_descale = layer._v_scale.expand(descale_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment