Unverified Commit e98d9346 authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

[1/2] Support FA4 for MHA Prefill in sgl-kernel (#10940)

parent 0c917410
......@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
[project]
name = "sgl-kernel"
version = "0.3.12"
version = "0.3.13"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.10"
......
......@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
[project]
name = "sgl-kernel"
version = "0.3.12"
version = "0.3.13"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.10"
......
......@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sgl-kernel"
version = "0.3.12"
version = "0.3.13"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.10"
......
......@@ -153,7 +153,43 @@ def flash_attn_with_kvcache(
normalization factor).
"""
if ver == 4:
raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
assert (
flash_attn_varlen_func_v4 is not None
), "FA4 is not available, please check your installation."
# Using `(-1, -1)` as no sliding window causes correctness issues for FA4.
assert (
k is None and v is None
), "FA4 does not support updating KV cache in-place."
assert (
rotary_cos is None
and rotary_sin is None
and rotary_interleaved is None
and rotary_seqlens is None
), "FA4 does not support rotary embedding."
assert (
cache_batch_idx is None and cache_leftpad is None
), "FA4 does not support non-consecutive batch indices or left padding."
assert (
q_descale is None and k_descale is None and v_descale is None
), "FA4 does not support descale."
if window_size == (-1, -1):
window_size = (None, None)
return flash_attn_varlen_func_v4(
q=q,
k=k_cache,
v=v_cache,
cu_seqlens_q=cu_seqlens_q,
seqused_k=cache_seqlens,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
pack_gqa=pack_gqa,
return_softmax_lse=return_softmax_lse,
learnable_sink=sinks,
page_table=page_table,
)
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
......
__version__ = "0.3.12"
__version__ = "0.3.13"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment