Unverified Commit 84810da4 authored by Trevor Morris, committed by GitHub

Add Cutlass MLA attention backend (#5390)

parent 40d9b8ac
@@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma
## Kernel backend
* `attention_backend`: This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, `cutlass_mla`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend (see the usage sketch below).
* `sampling_backend`: The backend for sampling.
## Constrained Decoding
......
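As a usage illustration (not part of this diff), the sketch below selects the new backend through the offline engine. It assumes `sglang.Engine` forwards server arguments as keyword arguments and that a CUDA device with the `sgl_kernel` Cutlass MLA kernels is available; the model path is a placeholder.

```python
# Hypothetical usage sketch: run a DeepSeek-style MLA model with the Cutlass MLA backend.
# Assumes a CUDA GPU and that sgl_kernel provides cutlass_mla_decode.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MLA checkpoint
    attention_backend="cutlass_mla",            # select the new backend
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))

# Equivalent server launch:
#   python -m sglang.launch_server --model-path <model> --attention-backend cutlass_mla
# Note: the ServerArgs change below forces page_size to 128 for this backend.
```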
from __future__ import annotations
"""
Support attention backend for Cutlass MLA.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
import triton
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_cuda
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInfo
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
# Cutlass MLA only supports pagesize=128
PAGE_SIZE = 128
@dataclass
class CutlassMLADecodeMetadata:
workspace: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
def __init__(
self,
workspace: Optional[torch.Tensor] = None,
block_kv_indices: Optional[torch.Tensor] = None,
):
self.workspace = workspace
self.block_kv_indices = block_kv_indices
class CutlassMLABackend(FlashInferMLAAttnBackend):
"""Cutlass attention kernels."""
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None,
):
super().__init__(
model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
)
self.num_q_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.num_local_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.forward_metadata: Optional[CutlassMLADecodeMetadata] = None
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.v_head_dim = model_runner.model_config.v_head_dim
self.scaling = model_runner.model_config.scaling
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
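# For decode/idle batches without speculative-decoding info, the method below builds a
# page-aligned block table and allocates the CUTLASS workspace; all other cases fall back
# to the FlashInfer MLA parent implementation.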
def init_forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
spec_info = forward_batch.spec_info
if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(
forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
PAGE_SIZE,
)
workspace_size = cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE, bs
)
workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
)
self.forward_metadata = CutlassMLADecodeMetadata(
workspace,
block_kv_indices,
)
else:
super().init_forward_metadata(forward_batch)
else:
super().init_forward_metadata(forward_batch)
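# CUDA graph support: pre-allocate a block table sized for max_context_len plus a matching
# workspace once, then fill them in place during graph capture and replay.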
def init_cuda_graph_state(
self,
max_bs: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:
cuda_graph_kv_indices = torch.full(
(max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
1,
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = block_kv_indices
workspace_size = cutlass_mla_get_workspace_size(
cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
)
self.cuda_graph_mla_workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
)
self.cuda_graph_kv_indices = cuda_graph_kv_indices
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
PAGE_SIZE,
)
workspace_size = cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE, bs
)
self.cuda_graph_mla_workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
)
self.forward_metadata = CutlassMLADecodeMetadata(
self.cuda_graph_mla_workspace,
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
else:
super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
assert seq_lens_cpu is not None
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
PAGE_SIZE,
)
workspace_size = cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE, bs
)
self.cuda_graph_mla_workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
)
self.forward_metadata.workspace = self.cuda_graph_mla_workspace
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
else:
super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)
def get_cuda_graph_seq_len_fill_value(self):
return 1
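# Decode: optionally write this step's latent K/V into the pool, then run the CUTLASS MLA
# decode kernel over the paged cache using the prebuilt block table and workspace.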
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
o = cutlass_mla_decode(
q_nope_and_q_pe=reshape_q,
kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
seq_lens=forward_batch.seq_lens.to(torch.int32),
page_table=self.forward_metadata.block_kv_indices,
workspace=self.forward_metadata.workspace,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
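To make the decode-metadata construction above concrete, here is a standalone, CPU-only sketch of the paging math with made-up sequence lengths; the real backend fills the block table via `create_flashmla_kv_indices_triton` and sizes the workspace with `cutlass_mla_get_workspace_size`, neither of which is called here.

```python
# Standalone sketch of the paging math behind CutlassMLADecodeMetadata.
# PAGE_SIZE is 128 because that is the only page size the Cutlass MLA kernel supports.
import torch

PAGE_SIZE = 128
seq_lens = torch.tensor([300, 70, 1024])  # per-request KV lengths (illustrative)
bs = seq_lens.numel()

# Pad the longest sequence to a whole number of pages (equivalent to triton.cdiv).
max_seqlen_pad = (int(seq_lens.max()) + PAGE_SIZE - 1) // PAGE_SIZE

# Block table: one row per request, one slot per page, -1 marks unused slots.
# The backend fills valid slots with physical page indices via the Triton kernel.
block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)

print(block_kv_indices.shape)  # torch.Size([3, 8]) -> 1024 tokens / 128 tokens per page
# On a CUDA device, the workspace would then be sized with
# cutlass_mla_get_workspace_size(max_seqlen_pad * PAGE_SIZE, bs).
```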
@@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton(
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
kv_indices_ptr_stride: tl.constexpr,
PAGED_SIZE: tl.constexpr = 64,
):
BLOCK_SIZE: tl.constexpr = 4096
NUM_PAGE_PER_BLOCK: tl.constexpr = 64
pid = tl.program_id(axis=0)
......
@@ -1515,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
or global_server_args_dict["attention_backend"] == "flashmla"
or global_server_args_dict["attention_backend"] == "fa3"
or global_server_args_dict["attention_backend"] == "cutlass_mla"
):
seq_lens_cpu = self.seq_lens.cpu()
else:
......
@@ -271,6 +271,7 @@ class ModelRunner:
"fa3",
"triton",
"flashmla",
"cutlass_mla",
]:
logger.info(
f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -926,6 +927,12 @@ class ModelRunner:
)
self.attn_backend = FlashAttentionBackend(self)
elif self.server_args.attention_backend == "cutlass_mla":
from sglang.srt.layers.attention.cutlass_mla_backend import (
CutlassMLABackend,
)
self.attn_backend = CutlassMLABackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
......
@@ -256,6 +256,12 @@ class ServerArgs:
)
self.page_size = 64
if self.attention_backend == "cutlass_mla":
logger.warning(
"Cutlass MLA only supports a page_size of 128, change page_size to 128."
)
self.page_size = 128
# Set cuda graph max batch size
if self.cuda_graph_max_bs is None:
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -823,7 +829,14 @@ class ServerArgs:
parser.add_argument(
"--attention-backend",
type=str,
choices=[
"flashinfer",
"triton",
"torch_native",
"fa3",
"flashmla",
"cutlass_mla",
],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
......
@@ -78,6 +78,7 @@ def cutlass_mla_decode(
assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
assert B_block_table == B_q
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
assert block_num % (128 / PAGE_SIZE) == 0
# TODO(kaixih@nvidia): support fp8
@@ -109,6 +110,8 @@ def cutlass_mla_decode(
def cutlass_mla_get_workspace_size(
max_seq_len: int, num_batches: int, sm_count: int = 0
) -> int:
assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
max_seq_len, num_batches, sm_count
)
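For reference, a shape-only sketch of the arguments that `forward_decode` hands to `cutlass_mla_decode`; the kernel is not invoked here, and the head count and latent dimensions are illustrative DeepSeek-style values rather than anything taken from this diff.

```python
# Shape-only illustration of the cutlass_mla_decode call; no GPU kernel is launched.
import torch

bs, num_heads = 2, 16                       # illustrative batch size and TP-local heads
kv_lora_rank, qk_rope_head_dim = 512, 64    # DeepSeek-style MLA dimensions (assumed)
head_dim = kv_lora_rank + qk_rope_head_dim  # 576: packed "nope" + rope query/cache dim
PAGE_SIZE, num_pages = 128, 12

q_nope_and_q_pe = torch.randn(bs, num_heads, head_dim, dtype=torch.bfloat16)
kv_c_and_k_pe_cache = torch.randn(num_pages, PAGE_SIZE, head_dim, dtype=torch.bfloat16)
seq_lens = torch.tensor([300, 1200], dtype=torch.int32)
page_table = torch.full((bs, 10), -1, dtype=torch.int32)  # ceil(1200 / 128) = 10 pages

# The assertions added above require a 2-D, non-empty block table whose first
# dimension matches the query batch size.
assert page_table.ndim == 2 and page_table.shape[0] == bs and page_table.shape[1] > 0
```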