Unverified commit fe9c3d6c, authored by Julian Huang and committed by GitHub
Browse files

[TurboQuant] enable FA3/FA4 for prefill paths (#40092)


Signed-off-by: 墨楼 <huangzhilin.hzl@antgroup.com>
Co-authored-by: 墨楼 <huangzhilin.hzl@antgroup.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Codex <codex@openai.com>
parent ccaf5ffa
......@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.78
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k3v4_nc --enforce-eager --max-model-len 4096"
server_args: "--kv-cache-dtype turboquant_k3v4_nc --max-model-len 4096"
......@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k8v4 --enforce-eager --max-model-len 4096"
server_args: "--kv-cache-dtype turboquant_k8v4 --max-model-len 4096"
......@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.75
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_3bit_nc --enforce-eager --max-model-len 4096"
server_args: "--kv-cache-dtype turboquant_3bit_nc --max-model-len 4096"
......@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_4bit_nc --enforce-eager --max-model-len 4096"
server_args: "--kv-cache-dtype turboquant_4bit_nc --max-model-len 4096"
......@@ -255,11 +255,16 @@ class FlashAttentionMetadata:
def _get_sliding_window_configs(
    vllm_config: VllmConfig,
) -> set[tuple[int, int] | None]:
    """Get the set of all sliding window configs used in the model.

    Only inspects FlashAttentionImpl layers. Other backends (e.g.
    TurboQuant, MLA) use their own metadata builders and are skipped.

    Args:
        vllm_config: Config from which the ``Attention`` layers are
            looked up via ``get_layers_from_vllm_config``.

    Returns:
        The distinct ``sliding_window`` values of every layer whose
        ``impl`` is a ``FlashAttentionImpl``. Empty if there are none.
    """
    layers = get_layers_from_vllm_config(vllm_config, Attention)
    # Filter rather than assert: an assert here would abort on the first
    # non-FlashAttention layer and make the skip unreachable (and asserts
    # are stripped under ``python -O`` anyway).
    return {
        layer.impl.sliding_window
        for layer in layers.values()
        if isinstance(layer.impl, FlashAttentionImpl)
    }
......
......@@ -39,6 +39,7 @@ from vllm.v1.attention.backend import (
MultipleOf,
)
from vllm.v1.attention.backends.fa_utils import (
get_flash_attn_version,
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
......@@ -271,6 +272,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1
# Detect flash-attn version (FA2/3/4) for prefill paths.
self.fa_version = get_flash_attn_version(head_size=head_size)
# Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
# and benchmarks show no regression vs dynamic in eager mode).
vllm_config = get_current_vllm_config()
......@@ -278,6 +282,43 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
)
def _flash_attn_varlen(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
) -> torch.Tensor:
    """Call ``flash_attn_varlen_func`` with this layer's fixed settings.

    Always passes ``softmax_scale=self.scale`` and ``causal=True``; the
    remaining arguments are forwarded unchanged. ``fa_version`` is added
    to the call only when ``self.fa_version`` is not ``None``.

    Returns:
        Whatever ``flash_attn_varlen_func`` returns for these inputs.
    """
    # fa_utils.get_flash_attn_version() returns None on backends that
    # should not pass an explicit fa_version kwarg, so the keyword is
    # omitted entirely in that case instead of being sent as None.
    version_kwargs = {}
    if self.fa_version is not None:
        version_kwargs["fa_version"] = self.fa_version

    return flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale,
        causal=True,
        **version_kwargs,
    )
def _ensure_on_device(self, layer, device):
"""One-time derivation of TQ buffers (rotation matrix, midpoints).
......@@ -503,7 +544,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
# max_query_len == max_seq_len means no request has prior cached KV.
# Both are Python ints — no GPU sync.
if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
return flash_attn_varlen_func(
return self._flash_attn_varlen(
q=query,
k=key,
v=value,
......@@ -511,8 +552,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
cu_seqlens_k=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=self.scale,
causal=True,
)
# Continuation or no flash_attn: per-request attention.
......@@ -552,7 +591,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
if _HAS_FLASH_ATTN:
_cu_2[1] = q_len
cu = _cu_2
out = flash_attn_varlen_func(
out = self._flash_attn_varlen(
q=q_seq,
k=k_seq,
v=v_seq,
......@@ -560,8 +599,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
cu_seqlens_k=cu,
max_seqlen_q=q_len,
max_seqlen_k=q_len,
softmax_scale=self.scale,
causal=True,
)
else:
q_t = q_seq.transpose(0, 1).contiguous()
......@@ -726,7 +763,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
if _HAS_FLASH_ATTN:
cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
return flash_attn_varlen_func(
return self._flash_attn_varlen(
q=query,
k=k_full,
v=v_full,
......@@ -734,8 +771,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# SDPA fallback: expand KV for GQA, build causal mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment