Unverified commit fe9c3d6c, authored by Julian Huang, committed by GitHub
Browse files

[TurboQuant] enable FA3/FA4 for prefill paths (#40092)


Signed-off-by: 墨楼 <huangzhilin.hzl@antgroup.com>
Co-authored-by: 墨楼 <huangzhilin.hzl@antgroup.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Codex <codex@openai.com>
parent ccaf5ffa
...@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B" ...@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.78 accuracy_threshold: 0.78
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k3v4_nc --enforce-eager --max-model-len 4096" server_args: "--kv-cache-dtype turboquant_k3v4_nc --max-model-len 4096"
...@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B" ...@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80 accuracy_threshold: 0.80
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k8v4 --enforce-eager --max-model-len 4096" server_args: "--kv-cache-dtype turboquant_k8v4 --max-model-len 4096"
...@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B" ...@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.75 accuracy_threshold: 0.75
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_3bit_nc --enforce-eager --max-model-len 4096" server_args: "--kv-cache-dtype turboquant_3bit_nc --max-model-len 4096"
...@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B" ...@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80 accuracy_threshold: 0.80
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_4bit_nc --enforce-eager --max-model-len 4096" server_args: "--kv-cache-dtype turboquant_4bit_nc --max-model-len 4096"
...@@ -255,11 +255,16 @@ class FlashAttentionMetadata: ...@@ -255,11 +255,16 @@ class FlashAttentionMetadata:
def _get_sliding_window_configs( def _get_sliding_window_configs(
vllm_config: VllmConfig, vllm_config: VllmConfig,
) -> set[tuple[int, int] | None]: ) -> set[tuple[int, int] | None]:
"""Get the set of all sliding window configs used in the model.""" """Get the set of all sliding window configs used in the model.
Only inspects FlashAttentionImpl layers. Other backends (e.g.
TurboQuant, MLA) use their own metadata builders and are skipped.
"""
sliding_window_configs: set[tuple[int, int] | None] = set() sliding_window_configs: set[tuple[int, int] | None] = set()
layers = get_layers_from_vllm_config(vllm_config, Attention) layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer in layers.values(): for layer in layers.values():
assert isinstance(layer.impl, FlashAttentionImpl) if not isinstance(layer.impl, FlashAttentionImpl):
continue
sliding_window_configs.add(layer.impl.sliding_window) sliding_window_configs.add(layer.impl.sliding_window)
return sliding_window_configs return sliding_window_configs
......
...@@ -39,6 +39,7 @@ from vllm.v1.attention.backend import ( ...@@ -39,6 +39,7 @@ from vllm.v1.attention.backend import (
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
get_flash_attn_version,
is_flash_attn_varlen_func_available, is_flash_attn_varlen_func_available,
) )
from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.attention.backends.utils import split_decodes_and_prefills
...@@ -271,6 +272,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -271,6 +272,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8) self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1 self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1
# Detect flash-attn version (FA2/3/4) for prefill paths.
self.fa_version = get_flash_attn_version(head_size=head_size)
# Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph, # Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
# and benchmarks show no regression vs dynamic in eager mode). # and benchmarks show no regression vs dynamic in eager mode).
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
...@@ -278,6 +282,43 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -278,6 +282,43 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
) )
def _flash_attn_varlen(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
) -> torch.Tensor:
    """Call ``flash_attn_varlen_func`` with this layer's fixed settings.

    Every call is causal and uses ``self.scale`` as the softmax scale.
    The ``fa_version`` keyword is forwarded only when ``self.fa_version``
    is not ``None``.

    Args:
        q: Query tensor, forwarded unchanged.
        k: Key tensor, forwarded unchanged.
        v: Value tensor, forwarded unchanged.
        cu_seqlens_q: Cumulative query sequence lengths, forwarded unchanged.
        cu_seqlens_k: Cumulative key sequence lengths, forwarded unchanged.
        max_seqlen_q: Maximum query sequence length, forwarded unchanged.
        max_seqlen_k: Maximum key sequence length, forwarded unchanged.

    Returns:
        Whatever ``flash_attn_varlen_func`` returns for these arguments.
    """
    # fa_utils.get_flash_attn_version() returns None on backends that
    # should not pass an explicit fa_version kwarg, so the keyword is
    # omitted entirely in that case rather than passed as None.
    version_kwarg = (
        {} if self.fa_version is None else {"fa_version": self.fa_version}
    )
    return flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale,
        causal=True,
        **version_kwarg,
    )
def _ensure_on_device(self, layer, device): def _ensure_on_device(self, layer, device):
"""One-time derivation of TQ buffers (rotation matrix, midpoints). """One-time derivation of TQ buffers (rotation matrix, midpoints).
...@@ -503,7 +544,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -503,7 +544,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
# max_query_len == max_seq_len means no request has prior cached KV. # max_query_len == max_seq_len means no request has prior cached KV.
# Both are Python ints — no GPU sync. # Both are Python ints — no GPU sync.
if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len: if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
return flash_attn_varlen_func( return self._flash_attn_varlen(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -511,8 +552,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -511,8 +552,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
cu_seqlens_k=attn_metadata.query_start_loc, cu_seqlens_k=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len, max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=attn_metadata.max_query_len, max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=self.scale,
causal=True,
) )
# Continuation or no flash_attn: per-request attention. # Continuation or no flash_attn: per-request attention.
...@@ -552,7 +591,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -552,7 +591,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
if _HAS_FLASH_ATTN: if _HAS_FLASH_ATTN:
_cu_2[1] = q_len _cu_2[1] = q_len
cu = _cu_2 cu = _cu_2
out = flash_attn_varlen_func( out = self._flash_attn_varlen(
q=q_seq, q=q_seq,
k=k_seq, k=k_seq,
v=v_seq, v=v_seq,
...@@ -560,8 +599,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -560,8 +599,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
cu_seqlens_k=cu, cu_seqlens_k=cu,
max_seqlen_q=q_len, max_seqlen_q=q_len,
max_seqlen_k=q_len, max_seqlen_k=q_len,
softmax_scale=self.scale,
causal=True,
) )
else: else:
q_t = q_seq.transpose(0, 1).contiguous() q_t = q_seq.transpose(0, 1).contiguous()
...@@ -726,7 +763,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -726,7 +763,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
if _HAS_FLASH_ATTN: if _HAS_FLASH_ATTN:
cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32) cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32) cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
return flash_attn_varlen_func( return self._flash_attn_varlen(
q=query, q=query,
k=k_full, k=k_full,
v=v_full, v=v_full,
...@@ -734,8 +771,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -734,8 +771,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
cu_seqlens_k=cu_seqlens_k, cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len, max_seqlen_q=q_len,
max_seqlen_k=seq_len, max_seqlen_k=seq_len,
softmax_scale=self.scale,
causal=True,
) )
else: else:
# SDPA fallback: expand KV for GQA, build causal mask # SDPA fallback: expand KV for GQA, build causal mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment