Unverified commit 20c90be2, authored by Baizhou Zhang, committed by GitHub

[Feature] Support FA3 backend for MLA (#4831)

parent ec3ee028
File: sglang/srt/layers/attention/flashattention_backend.py

@@ -13,7 +13,9 @@ from typing import TYPE_CHECKING, Optional, Union

import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:

@@ -58,6 +60,9 @@ class FlashAttentionBackend(AttentionBackend):
        self.decode_cuda_graph_metadata = {}
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.page_size = model_runner.page_size
        self.use_mla = (
            model_runner.model_config.attention_arch == AttentionArch.MLA
        ) and (not global_server_args_dict["disable_mla"])

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata to cache repetitive calculations."""

@@ -117,23 +122,30 @@ class FlashAttentionBackend(AttentionBackend):
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                if not self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )
                else:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        v,
                    )

        # Use precomputed metadata
        metadata = self.forward_metadata

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
        # here is two side inclusive

@@ -142,36 +154,72 @@ class FlashAttentionBackend(AttentionBackend):
            if layer.sliding_window_size is not None
            else (-1, -1)
        )

        page_table = metadata.page_table

        # # Use Flash Attention for prefill
        if not self.use_mla:
            # Do multi-head attention
            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
            o = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=metadata.max_seq_len_q,
                softmax_scale=layer.scaling,
                causal=True,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]
            o = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=metadata.max_seq_len_q,
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,

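Note on the MLA prefill branch above: the layer's key buffer holds one combined latent entry per cached token, with the first `layer.v_head_dim` features being the compressed KV latent (`c_kv`) and the remaining `layer.head_dim - layer.v_head_dim` features being the decoupled RoPE key (`k_rope`). For DeepSeek-style models these appear to correspond to kv_lora_rank and qk_rope_head_dim. The sketch below is illustrative only (it is not part of the commit); the sizes kv_lora_rank=512 and qk_rope_head_dim=64 are assumed example values.

import torch

# Illustrative sketch (not part of the commit): slicing a DeepSeek-style MLA
# latent cache the same way the backend does. Sizes here are assumptions.
kv_lora_rank = 512        # plays the role of layer.v_head_dim
qk_rope_head_dim = 64     # so layer.head_dim == kv_lora_rank + qk_rope_head_dim
head_dim = kv_lora_rank + qk_rope_head_dim
num_tokens, num_kv_heads = 8, 1   # MLA stores a single latent "head" per token

# Latent cache as stored in the pool: one row per token slot.
kv_cache = torch.randn(num_tokens, num_kv_heads, head_dim)

# Same slicing as the backend: latent part first, RoPE key part last.
c_kv = kv_cache[:, :, :kv_lora_rank]     # compressed KV latent
k_rope = kv_cache[:, :, kv_lora_rank:]   # decoupled RoPE key

assert c_kv.shape[-1] == kv_lora_rank
assert k_rope.shape[-1] == qk_rope_head_dim
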
@@ -184,24 +232,29 @@ class FlashAttentionBackend(AttentionBackend):
    ) -> torch.Tensor:
        """Forward pass with FlashAttention using precomputed metadata."""
        # Save KV cache if needed
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                if not self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )
                else:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer,
                        cache_loc,
                        k,
                        v,
                    )

        # Use precomputed metadata
        metadata = self.forward_metadata

        # Calculate window size (can be moved to metadata if layer properties don't change)
        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
        # here is two side inclusive

@@ -210,33 +263,79 @@ class FlashAttentionBackend(AttentionBackend):
            if layer.sliding_window_size is not None
            else (-1, -1)
        )

        page_table = metadata.page_table

        if not self.use_mla:
            # Do multi-head attention

            # Get KV cache
            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )

            # Pre-reshape query tensor
            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

            # Run attention with precomputed values
            o = flash_attn_with_kvcache(
                q=q_reshaped,
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
                causal=True,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

            o = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def init_cuda_graph_state(self, max_bs: int):
        """Initialize CUDA graph state for the attention backend.

@@ -286,7 +385,6 @@ class FlashAttentionBackend(AttentionBackend):
        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
            req_pool_indices, :
        ]
        if forward_mode == ForwardMode.DECODE:
            # Precompute cumulative sequence lengths
            metadata.cu_seqlens_q = torch.arange(

File: sglang/srt/model_executor/model_runner.py
@@ -230,6 +230,10 @@ class ModelRunner:
            elif server_args.enable_flashmla:
                logger.info("MLA optimization is turned on. Use flashmla decode.")
                server_args.attention_backend = "flashmla"
            elif server_args.attention_backend == "fa3":
                logger.info(
                    f"MLA optimization is turned on. Use flash attention 3 backend."
                )
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                server_args.attention_backend = "triton"

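With this hunk, an MLA model launched with `--attention-backend fa3` keeps the FA3 backend instead of being overridden to triton; the flashinfer-MLA and flashmla flags still win earlier in the elif chain. A rough sketch of the effective selection is shown below; it is not the project's code, and the `enable_flashinfer_mla` branch (outside the visible hunk) is assumed to behave like the rest of this chain.

# Illustrative sketch (not part of the commit) of the MLA backend choice
# after this change. Attribute names are assumed to match ServerArgs.
def pick_mla_backend(server_args) -> str:
    if getattr(server_args, "enable_flashinfer_mla", False):
        return "flashinfer_mla"
    if server_args.enable_flashmla:
        return "flashmla"
    if server_args.attention_backend == "fa3":
        return "fa3"      # new: FA3 is no longer overridden for MLA models
    return "triton"       # default MLA backend
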
...@@ -879,7 +883,7 @@ class ModelRunner: ...@@ -879,7 +883,7 @@ class ModelRunner:
"Please use `--attention-backend flashinfer`." "Please use `--attention-backend flashinfer`."
) )
logger.warning( logger.warning(
"FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported." "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
) )
from sglang.srt.layers.attention.flashattention_backend import ( from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend, FlashAttentionBackend,
File: sglang/srt/models/deepseek_v2.py
@@ -655,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

    def no_absorb(self, forward_batch: ForwardBatch) -> bool:

@@ -667,6 +668,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            )
        elif self.attention_backend == "fa3":
            # Flash Attention: Keep absorbing for all extend/decode
            return False
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            return (

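`no_absorb()` decides whether the layer falls back to plain multi-head attention instead of the weight-absorbed latent path. For the fa3 backend it now always returns False, so both extend (prefill) and decode take the absorbed path, which is what the new FA3 MLA branches in the attention backend expect. A minimal sketch of the resulting dispatch is below, assuming the class exposes `forward_normal` and `forward_absorb` for the two paths (assumed names, not shown in this diff).

# Illustrative sketch (not part of the commit): once no_absorb() returns False
# for fa3, the absorbed path is always taken. forward_normal / forward_absorb
# are assumed names for the two code paths in this class.
def mla_forward(attn, positions, hidden_states, forward_batch):
    if attn.no_absorb(forward_batch):   # never True when attention_backend == "fa3"
        return attn.forward_normal(positions, hidden_states, forward_batch)
    return attn.forward_absorb(positions, hidden_states, forward_batch)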