"tests/python/common/test_heterograph-misc.py" did not exist on "a9ffb59e662398d9c40ed6ebe288e5542715e0dd"
Unverified commit 57de7c6b authored by Yineng Zhang, committed by GitHub

feat: use fa3 mla by default on hopper (#5210)


Co-authored-by: yundai424 <yundai424@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
parent 115ae2e7
@@ -325,7 +325,7 @@ class FlashAttentionBackend(AttentionBackend):
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
                 metadata.cache_seqlens_int32 = (
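
For context, the intent of the `is_decode()` -> `is_decode_or_idle()` switch is that idle batches reuse the decode metadata path. Below is a simplified, hypothetical stand-in for the mode enum (not sglang's actual `ForwardMode` implementation), only to illustrate the combined check:

```python
from enum import Enum, auto


class ForwardMode(Enum):
    """Hypothetical, minimal stand-in for a forward-mode enum."""

    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()  # e.g. a step with no requests to run

    def is_decode(self) -> bool:
        return self is ForwardMode.DECODE

    def is_decode_or_idle(self) -> bool:
        # Idle batches follow the same metadata path as decode.
        return self in (ForwardMode.DECODE, ForwardMode.IDLE)
```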
@@ -527,7 +527,9 @@ class FlashAttentionBackend(AttentionBackend):
             else (-1, -1)
         )
         k_descale, v_descale = None, None
-        if self.kv_cache_dtype_str != "auto":
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
             descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
             k_descale = layer.k_scale.expand(descale_shape)
             v_descale = layer.v_scale.expand(descale_shape)
@@ -670,10 +672,13 @@ class FlashAttentionBackend(AttentionBackend):
         causal = not layer.is_cross_attention

         k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
         if self.kv_cache_dtype_str != "auto":
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)

         if not self.use_mla:
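
The two hunks above guard KV descaling behind `layer.k_scale is not None`. A minimal, self-contained sketch of that guarded expansion, using a hypothetical `layer` object (with assumed `tp_k_head_num`, `k_scale`, and `v_scale` attributes) rather than sglang's RadixAttention layer:

```python
import torch


class _Layer:
    """Hypothetical stand-in for a layer carrying per-layer KV scales."""

    tp_k_head_num = 8
    k_scale = torch.tensor(0.5)
    v_scale = torch.tensor(0.5)


def get_kv_descale(kv_cache_dtype_str, layer, batch_size):
    """Return (k_descale, v_descale), or (None, None) when scaling is skipped."""
    k_descale, v_descale = None, None
    # Only scale when an FP8 KV cache is explicitly requested AND the
    # quantization method actually attached a k_scale to this layer.
    if kv_cache_dtype_str != "auto" and getattr(layer, "k_scale", None) is not None:
        descale_shape = (batch_size, layer.tp_k_head_num)  # (batch, kv heads per TP rank)
        k_descale = layer.k_scale.expand(descale_shape)
        v_descale = layer.v_scale.expand(descale_shape)
    return k_descale, v_descale


k, v = get_kv_descale("fp8_e4m3", _Layer(), batch_size=4)  # -> two (4, 8) tensors
```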
@@ -834,7 +839,7 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
@@ -937,7 +942,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]
             if spec_info is not None:
...
@@ -80,6 +80,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_hopper_with_cuda_12_3,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -245,7 +246,16 @@ class ModelRunner:
                     "flashinfer" if is_flashinfer_available() else "triton"
                 )
             else:
-                server_args.attention_backend = "triton"
+                if is_hopper_with_cuda_12_3():
+                    if server_args.speculative_eagle_topk is None or (
+                        server_args.speculative_eagle_topk is not None
+                        and server_args.speculative_eagle_topk == 1
+                    ):
+                        server_args.attention_backend = "fa3"
+                    else:
+                        server_args.attention_backend = "triton"
+                else:
+                    server_args.attention_backend = "triton"
             logger.info(
                 f"Attention backend not set. Use {server_args.attention_backend} backend by default."
             )
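
Read as a decision rule, the new fallback above defaults to fa3 on Hopper with CUDA >= 12.3 unless EAGLE speculative decoding uses topk > 1. A condensed sketch of the same choice as a standalone helper (hypothetical function, not part of the diff):

```python
def default_attention_backend(hopper_with_cuda_12_3: bool, speculative_eagle_topk) -> str:
    """Condensed form of the fallback branch above."""
    if hopper_with_cuda_12_3 and speculative_eagle_topk in (None, 1):
        return "fa3"
    return "triton"


assert default_attention_backend(True, None) == "fa3"
assert default_attention_backend(True, 2) == "triton"
assert default_attention_backend(False, None) == "triton"
```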
@@ -263,6 +273,16 @@ class ModelRunner:
             else:
                 raise ValueError(f"MLA optimization not supported on CPU.")

+        if (
+            server_args.attention_backend == "fa3"
+            and server_args.kv_cache_dtype == "fp8_e5m2"
+        ):
+            logger.warning(
+                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+                "Setting attention backend to triton."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -889,9 +909,6 @@ class ModelRunner:
                     "FlashAttention v3 Backend requires SM>=90. "
                     "Please use `--attention-backend flashinfer`."
                 )
-            logger.warning(
-                "FlashAttention v3 Backend is in Beta. FP8 is not supported."
-            )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
             )
...
@@ -1828,3 +1828,12 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def is_hopper_with_cuda_12_3():
+    if not is_cuda():
+        return False
+    is_hopper = torch.cuda.get_device_capability()[0] == 9
+    cuda_version = torch.version.cuda.split(".")
+    is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
+    return is_hopper and is_cuda_compatible
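
The helper added above keys off the GPU compute capability and the CUDA version PyTorch was built against. A quick way to inspect the same values interactively (assumes a CUDA build of PyTorch is installed):

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability: {major}.{minor}")   # Hopper (e.g. H100) reports 9.x
    print(f"built with CUDA:    {torch.version.cuda}")  # e.g. "12.4"
```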