Unverified commit a53fe428 authored by lukec, committed by GitHub

Support FlashMLA backend (#4472)


Co-authored-by: yinfan98 <1106310035@qq.com>
parent 1b859295
from __future__ import annotations
"""
Support attention backend for flashMLA.
Current initial integration of FlashMLA shows normal accuracy, but performance is slightly lacking.
#TODO
Support FlashMLA decode with cudagraph
Enable speculative sampling in FlashMLA
Integrate FA3 prefill
"""
from typing import TYPE_CHECKING, Optional, Union

import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

# FlashMLA only supports pagesize=64
PAGE_SIZE = 64


class FlashMLABackend(FlashInferMLAAttnBackend):
    """FlashMLA attention kernels."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
        )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )

        bs = forward_batch.batch_size

        # Build the paged block table: one row per request, one entry per 64-token page.
        max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE)
        flashmla_index = torch.full(
            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
        )
        create_flashmla_kv_indices_triton[(bs,)](
            self.indices_updater_decode.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            None,
            flashmla_index,
            self.indices_updater_decode.req_to_token.size(1),
            flashmla_index.size(1),
            max_seqlen_pad,
        )

        mla_metadata, mla_splits = get_mla_metadata(
            forward_batch.seq_lens.to(torch.int32),
            1 * self.num_q_heads // self.num_kv_heads,
            self.num_kv_heads,
        )

        # The MLA KV cache stores a single compressed head of width
        # kv_lora_rank + qk_rope_head_dim per token, viewed here as 64-token pages.
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
            block_table=flashmla_index,
            cache_seqlens=forward_batch.seq_lens.to(torch.int32),
            head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
            tile_scheduler_metadata=mla_metadata,
            num_splits=mla_splits,
            softmax_scale=layer.scaling,
            causal=False,
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
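
For orientation, the following is a minimal shape walkthrough of the decode call above. It is a sketch with dummy CPU tensors and assumed sizes (kv_lora_rank=512, qk_rope_head_dim=64, 16 query heads on this TP rank); it does not call flash_mla_with_kvcache and only checks that the view/reshape bookkeeping around the kernel is consistent.

# Hypothetical shape sketch for FlashMLABackend.forward_decode; sizes are illustrative.
import torch

PAGE_SIZE = 64                        # FlashMLA page size (matches the backend constant)
bs = 2                                # decode batch size: one new token per request
tp_q_head_num = 16                    # query heads on this TP rank (assumed)
kv_lora_rank = 512                    # compressed KV dimension (assumed)
qk_rope_head_dim = 64                 # rotary part of the head dim (assumed)
head_dim = kv_lora_rank + qk_rope_head_dim        # 576: q/k head dim fed to FlashMLA
kv_cache_dim = kv_lora_rank + qk_rope_head_dim    # per-token KV cache width
v_head_dim = kv_lora_rank             # value head dim returned by the kernel (head_dim_v)

seq_lens = torch.tensor([100, 180], dtype=torch.int32)
max_seqlen_pad = (int(seq_lens.max()) + PAGE_SIZE - 1) // PAGE_SIZE  # triton.cdiv equivalent

# q arrives flattened as (num_tokens, num_heads * head_dim) and is viewed as
# (bs, q_len=1, num_heads, head_dim) before the FlashMLA call.
q = torch.randn(bs, tp_q_head_num * head_dim)
reshape_q = q.view(bs, -1, tp_q_head_num, head_dim)
assert reshape_q.shape == (bs, 1, tp_q_head_num, head_dim)

# The MLA KV cache stores one compressed "head" per token and is viewed as 64-token pages.
num_pages = 8
k_cache = torch.randn(num_pages * PAGE_SIZE, 1, kv_cache_dim)
paged_k_cache = k_cache.view(-1, PAGE_SIZE, 1, kv_cache_dim)
assert paged_k_cache.shape == (num_pages, PAGE_SIZE, 1, kv_cache_dim)

# The block table has one row per request and one page index per 64-token page.
block_table = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
assert block_table.shape == (2, 3)

# The kernel output is (bs, 1, num_heads, v_head_dim) and is flattened back to
# (num_tokens, num_heads * v_head_dim) for the rest of the model.
o = torch.randn(bs, 1, tp_q_head_num, v_head_dim)
assert o.view(-1, tp_q_head_num * v_head_dim).shape == (bs, tp_q_head_num * v_head_dim)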
@@ -15,6 +15,7 @@ def create_flashinfer_kv_indices_triton(
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)
    # find the req pool idx, this is for batch to token
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)
@@ -37,3 +38,56 @@ def create_flashinfer_kv_indices_triton(
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)


@triton.jit
def create_flashmla_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
    kv_indices_ptr_stride: tl.constexpr,
    max_pagesize: tl.constexpr,
):
    PAGED_SIZE: tl.constexpr = 64
    BLOCK_SIZE: tl.constexpr = 4096
    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
    pid = tl.program_id(axis=0)

    # find the req pool idx, this is for batch to token
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
    num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for i in range(num_pages_loop):
        paged_offset = (
            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
        ) * PAGED_SIZE
        paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK

        mask = paged_offset <= num_paged * PAGED_SIZE
        mask_out = paged_offset_out <= num_paged

        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + paged_offset,
            mask=mask,
        )
        tl.store(
            kv_indices_ptr + pid * kv_indices_ptr_stride + paged_offset_out,
            data // PAGED_SIZE,
            mask=mask_out,
        )
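
To make the index math above easier to follow, here is a hedged pure-PyTorch reference of the mapping the kernel computes in the decode path (kv_start_idx=None), ignoring the kernel's block-boundary masking details: for each request it samples req_to_token at the first slot of every 64-token page and converts that slot into a page id by integer division. The helper name and the tiny example data are made up for illustration; this is not part of the change.

# Hypothetical PyTorch reference for create_flashmla_kv_indices_triton (illustration only).
import torch

PAGED_SIZE = 64

def flashmla_block_table_reference(
    req_to_token: torch.Tensor,      # [max_batch, max_context_len], token slots per request
    req_pool_indices: torch.Tensor,  # [bs], row of req_to_token for each request
    seq_lens: torch.Tensor,          # [bs], current KV length of each request
    max_seqlen_pad: int,             # max number of 64-token pages across the batch
) -> torch.Tensor:
    bs = req_pool_indices.numel()
    block_table = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
    for b in range(bs):
        row = req_to_token[req_pool_indices[b]]
        num_pages = (int(seq_lens[b]) + PAGED_SIZE - 1) // PAGED_SIZE  # tl.cdiv
        # Take the first token slot of each page; pages are contiguous blocks of 64 slots,
        # so slot // 64 is the page id the FlashMLA kernel should read.
        first_slots = row[0 : num_pages * PAGED_SIZE : PAGED_SIZE]
        block_table[b, :num_pages] = (first_slots // PAGED_SIZE).to(torch.int32)
    return block_table

# Tiny example: two requests whose KV slots happen to live in pages 0-2 and 5-6.
req_to_token = torch.zeros((4, 512), dtype=torch.int64)
req_to_token[1, :130] = torch.arange(0, 130)    # request A: slots 0..129
req_to_token[3, :70] = torch.arange(320, 390)   # request B: slots 320..389
table = flashmla_block_table_reference(
    req_to_token,
    req_pool_indices=torch.tensor([1, 3]),
    seq_lens=torch.tensor([130, 70]),
    max_seqlen_pad=3,
)
print(table)
# tensor([[ 0,  1,  2],
#         [ 5,  6, -1]], dtype=torch.int32)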
@@ -71,6 +71,7 @@ global_server_args_dict = {
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
    "enable_flashmla": ServerArgs.enable_flashmla,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
}
@@ -1273,7 +1274,10 @@ class ScheduleBatch:
    def get_model_worker_batch(self) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            if (
                global_server_args_dict["enable_flashinfer_mla"]
                or global_server_args_dict["enable_flashmla"]
            ):
                decode_seq_lens = self.seq_lens.cpu()
            else:
                decode_seq_lens = None
...
@@ -149,6 +149,7 @@ class ModelRunner:
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                "enable_flashmla": server_args.enable_flashmla,
                "disable_radix_cache": server_args.disable_radix_cache,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -223,6 +224,9 @@ class ModelRunner:
                    "MLA optimization is turned on. Use flashinfer mla backend."
                )
                server_args.attention_backend = "flashinfer_mla"
            elif server_args.enable_flashmla:
                logger.info("MLA optimization is turned on. Use flashmla decode.")
                server_args.attention_backend = "flashmla"
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                server_args.attention_backend = "triton"
@@ -840,6 +844,10 @@ class ModelRunner:
            )
            self.attn_backend = FlashInferMLAAttnBackend(self)
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
...
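
As a quick summary of the two ModelRunner hunks above, the MLA backend resolution now has three outcomes, with flashinfer_mla taking precedence over flashmla, and triton as the fallback. The tiny function below is only a hypothetical distillation of that precedence for readability, not code from the change.

# Hypothetical distillation of the MLA backend selection shown above.
def pick_mla_attention_backend(enable_flashinfer_mla: bool, enable_flashmla: bool) -> str:
    """Mirror the precedence in the diff: flashinfer_mla wins, then flashmla, else triton."""
    if enable_flashinfer_mla:
        return "flashinfer_mla"
    elif enable_flashmla:
        return "flashmla"
    else:
        return "triton"

assert pick_mla_attention_backend(False, True) == "flashmla"
assert pick_mla_attention_backend(True, True) == "flashinfer_mla"
assert pick_mla_attention_backend(False, False) == "triton"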
@@ -173,6 +173,7 @@ class ServerArgs:
    tool_call_parser: str = None
    enable_hierarchical_cache: bool = False
    enable_flashinfer_mla: bool = False
    enable_flashmla: bool = False
    flashinfer_mla_disable_ragged: bool = False
    warmups: Optional[str] = None
@@ -227,6 +228,8 @@ class ServerArgs:
        assert self.chunked_prefill_size % self.page_size == 0

        if self.enable_flashmla:
            assert self.page_size == 64, "FlashMLA only supports page_size=64"

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on lower-end
            # GPUs with HBM < 25 GB, you can either disable cuda graph or set
            # `cuda_graph_max_bs` to a very small value to reduce the memory overhead
            # of creating cuda graphs, with almost no impact on performance. However,
            # when serving models with TP4 or TP8, we need to enable cuda graph to
            # maintain high performance. In this case, we can set `cuda_graph_max_bs`
            # to 80 (half of the default value 160) to reduce the memory overhead of
            # creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b,
            # a value of 80 is sufficient and can reduce the memory overhead of
            # creating cuda graphs on lower-end GPUs compared to the original 160,
            # avoiding OOM issues.
@@ -753,6 +756,11 @@
            action="store_true",
            help="Enable FlashInfer MLA optimization",
        )
        parser.add_argument(
            "--enable-flashmla",
            action="store_true",
            help="Enable FlashMLA decode optimization",
        )
        parser.add_argument(
            "--flashinfer-mla-disable-ragged",
            action="store_true",
...
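
The new flag pairs with a page-size constraint: FlashMLA requires page_size=64, which the assertion added above enforces. The snippet below is a rough, self-contained argparse sketch of that wiring, not sglang's actual CLI plumbing; in practice the flag is simply passed to the server launcher (e.g. via --enable-flashmla, as the benchmark change below does).

# Hypothetical stand-in for the new --enable-flashmla flag and its page-size check.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-flashmla",
    action="store_true",
    help="Enable FlashMLA decode optimization",
)
parser.add_argument("--page-size", type=int, default=1)

args = parser.parse_args(["--enable-flashmla", "--page-size", "64"])

# Mirrors the validation added to ServerArgs above.
if args.enable_flashmla:
    assert args.page_size == 64, "FlashMLA only supports page_size=64"

print(args.enable_flashmla, args.page_size)  # True 64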
@@ -182,6 +182,12 @@ def main(args, server_args):
                "--enable-flashinfer-mla",
            ]
        )
    if server_args.enable_flashmla:
        other_args.extend(
            [
                "--enable-flashmla",
            ]
        )
    if server_args.quantization:
        other_args.extend(
...