Unverified Commit 0753ef83 authored by Yineng Zhang, committed by GitHub

[Auto Sync] Update flashattention_backend.py (20250922) (#10762)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Gordon Gustafson <ggustafson@together.ai>
parent 662393f2
@@ -692,8 +692,13 @@ class FlashAttentionBackend(AttentionBackend):
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
-        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
-        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
+        # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        if (
+            self.kv_cache_dtype_str != "auto"
+            and layer.head_dim <= 256
+            and self.fa_impl_ver != 4
+        ):
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
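
For context, here is a minimal standalone sketch of what the guarded block computes, using torch and hypothetical placeholder values for the fields the real code reads from self, forward_batch, and the RadixAttention layer (kv_cache_dtype_str, head_dim, fa_impl_ver, batch_size, tp_k_head_num, k_scale); it is not the backend implementation itself.

    import torch

    # Hypothetical stand-ins for values normally read from the backend, the batch,
    # and the quantized attention layer.
    kv_cache_dtype_str = "fp8_e4m3"  # self.kv_cache_dtype_str; "auto" means no explicit fp8 kv cache
    head_dim = 128                   # layer.head_dim; fp8 descaling is only used for head_dim <= 256
    fa_impl_ver = 3                  # self.fa_impl_ver; fa4 is excluded by the new check
    batch_size = 4                   # forward_batch.batch_size
    tp_k_head_num = 8                # layer.tp_k_head_num (kv heads on this tensor-parallel rank)
    k_scale = torch.tensor(0.0213)   # layer.k_scale, set by the fp8 quantization method

    k_descale = None
    if kv_cache_dtype_str != "auto" and head_dim <= 256 and fa_impl_ver != 4:
        if k_scale is not None:
            # One descale value per (request, kv head): expand broadcasts the scalar
            # scale to (batch_size, tp_k_head_num) without copying data.
            descale_shape = (batch_size, tp_k_head_num)
            k_descale = k_scale.expand(descale_shape)

    assert k_descale is not None and k_descale.shape == (4, 8)

With fa_impl_ver == 4 the outer condition fails, k_descale stays None, and no fp8 descale tensors are passed to the kernel, matching reason 4 added in the diff's comment.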