Unverified commit b6944f97 authored by lukec, committed by GitHub

Support FlashMLA backend cuda graph (#4514)


Co-authored-by: yinfan98 <1106310035@qq.com>
Co-authored-by: Hongbosherlock <hongbosherlock@gmail.com>
Co-authored-by: ispobock <ispobaoke@163.com>
parent f44db16c
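In short, this commit lets the FlashMLA decode path run under CUDA graph capture and replay: the tile-scheduler metadata and the paged KV block table are precomputed into persistent buffers instead of being rebuilt inside forward_decode, and ServerArgs now coerces page_size to 64 instead of asserting on it. As a rough standalone illustration of the paging math the backend relies on (a sketch, not sglang code; the names below are made up), the block table has one row per request and one column per 64-token page of the longest sequence:

import torch

PAGE_SIZE = 64  # FlashMLA only supports a page size of 64

def block_table_width(seq_lens: torch.Tensor) -> int:
    # ceil(max_seq_len / PAGE_SIZE): pages needed by the longest request,
    # i.e. the second dimension of block_kv_indices in the backend below.
    return (int(seq_lens.max().item()) + PAGE_SIZE - 1) // PAGE_SIZE

seq_lens = torch.tensor([130, 65, 1], dtype=torch.int32)
width = block_table_width(seq_lens)                      # ceil(130 / 64) = 3
block_table = torch.full((seq_lens.numel(), width), -1, dtype=torch.int32)
print(block_table.shape)                                 # torch.Size([3, 3])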
from __future__ import annotations
"""
Support attention backend for FlashMLA.
Current initial integration of FlashMLA shows normal accuracy, but performance is slightly lacking.
#TODO
Support FlashMLA decode with cudagraph
Enable speculative sampling in FlashMLA
Integrate FA3 prefill
""" """
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import torch
@@ -28,10 +25,30 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInfo
# FlashMLA only supports pagesize=64
PAGE_SIZE = 64
# TODO The current setup is hard-coded and will be changed after integrating with MTP.
Q_LEN = 1
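# Note: Q_LEN scales the second argument to get_mla_metadata() below
# (Q_LEN * num_q_heads // num_kv_heads, the query tokens scheduled per request
# and KV head). Q_LEN = 1 is plain one-token-per-step decode; MTP / speculative
# decoding would need a larger value, which is what the TODO above refers to.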
@dataclass
class FlashMLADecodeMetadata:
flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
num_splits: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
def __init__(
self,
flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
num_splits: Optional[torch.Tensor] = None,
block_kv_indices: Optional[torch.Tensor] = None,
):
self.flashmla_metadata = flashmla_metadata
self.num_splits = num_splits
self.block_kv_indices = block_kv_indices
class FlashMLABackend(FlashInferMLAAttnBackend):
@@ -58,6 +75,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.num_local_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.forward_metadata: Optional[FlashMLADecodeMetadata] = None
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
@@ -67,6 +85,163 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.q_data_type = model_runner.dtype
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
def init_forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
spec_info = forward_batch.spec_info
if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(
forward_batch.seq_lens.max().item(), PAGE_SIZE
)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
self.forward_metadata = FlashMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices,
)
else:
super().init_forward_metadata(forward_batch)
else:
super().init_forward_metadata(forward_batch)
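# Note: this eager (non-graph) path rebuilds block_kv_indices and the FlashMLA
# tile-scheduler metadata for every batch, whereas the cuda-graph path below
# reuses buffers preallocated in init_cuda_graph_state so their device addresses
# stay stable across graph replays.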
def init_cuda_graph_state(
self,
max_bs: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:
cuda_graph_kv_indices = torch.full(
(max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
1,
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = block_kv_indices
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
Q_LEN * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
self.cuda_graph_kv_indices = cuda_graph_kv_indices
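# Note: these buffers are sized once for the largest capturable batch and the
# full context length; capture and replay below only slice them and update them
# in place with copy_(), so the captured CUDA graph always sees the same storage.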
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = FlashMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
else:
super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
seq_lens = seq_lens[:bs]
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
else:
super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)
def forward_decode(
self,
q: torch.Tensor,
@@ -88,39 +263,18 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
v,
)
bs = forward_batch.batch_size
max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE)
flashmla_index = torch.full(
(bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
)
create_flashmla_kv_indices_triton[(bs,)](
self.indices_updater_decode.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
flashmla_index,
self.indices_updater_decode.req_to_token.size(1),
flashmla_index.size(1),
max_seqlen_pad,
)
mla_metadata, mla_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
1 * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
o, _ = flash_mla_with_kvcache(
q=reshape_q,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
block_table=self.forward_metadata.block_kv_indices,
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
causal=False,
)
...
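The capture/replay methods above follow the usual PyTorch CUDA graph discipline: a captured graph records fixed device addresses, so new inputs must be written into the captured buffers in place rather than bound to freshly allocated tensors, which is why replay updates self.cuda_graph_mla_metadata and self.cuda_graph_num_splits with copy_() and re-slices self.cuda_graph_kv_indices. A minimal, generic sketch of that pattern (not backend code; the tensors here are illustrative):

import torch

static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up on a side stream so one-time initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

# Capture one iteration; the graph remembers the buffers' addresses.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

# Replay with new data by overwriting the captured input in place.
static_in.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
g.replay()
print(static_out)  # tensor([ 0.,  2.,  4., ..., 14.], device='cuda:0')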
@@ -49,7 +49,6 @@ def create_flashmla_kv_indices_triton(
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
kv_indices_ptr_stride: tl.constexpr,
max_pagesize: tl.constexpr,
):
PAGED_SIZE: tl.constexpr = 64
BLOCK_SIZE: tl.constexpr = 4096
...
@@ -232,7 +232,10 @@ class ServerArgs:
assert self.chunked_prefill_size % self.page_size == 0
if self.enable_flashmla is True:
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
)
self.page_size = 64
# Set cuda graph max batch size
if self.cuda_graph_max_bs is None:
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
...
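Note the behavioral change in ServerArgs: the previous hard assert (which required the user to set page_size=64 themselves when FlashMLA was enabled) is replaced by a warning plus an automatic override. A minimal standalone sketch of the new behavior (illustrative only; the function name is made up):

import logging

logger = logging.getLogger(__name__)

def normalize_page_size(enable_flashmla: bool, page_size: int) -> int:
    # Mirrors the ServerArgs change above: warn and coerce instead of asserting.
    if enable_flashmla:
        logger.warning("FlashMLA only supports page_size=64; overriding page_size to 64.")
        return 64
    return page_size

print(normalize_page_size(True, 16))  # logs a warning and returns 64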