"""FlashMLA-based attention backend for DCU.

Decode and target-verify batches run through FlashMLA paged-KV kernels
(page size fixed at 64); prefill/extend is delegated to the FlashAttention
backend.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import triton

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

try:
    from flash_mla import (
        flash_mla_with_kvcache,
        flash_mla_with_kvcache_quantization,
        get_mla_metadata,
    )

    _has_flash_mla = True
except Exception:
    try:
        from vllm.attention.ops.flashmla import (
            flash_mla_with_kvcache,
            get_mla_metadata,
        )

        _has_flash_mla = False
    except Exception:
        raise ImportError(
            "Cannot import FlashMLA. To use the flashmla backend, install one of:\n"
            "    pip install flash-mla\n"
            "or\n"
            "    pip install vllm"
        )

PAGE_SIZE = 64  # FlashMLA requires a fixed page size of 64.

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInput


@dataclass
class VllmMLADecodeMetadata:
    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    num_splits: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None


class DCUMLABackend(AttentionBackend):
    def __init__(
        self,
        model_runner: "ModelRunner",
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        if model_runner.server_args.page_size != PAGE_SIZE:
            raise ValueError(
                f"dcu_mla backend requires page_size={PAGE_SIZE}, "
                f"but got {model_runner.server_args.page_size}"
            )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.device = model_runner.device
        self.max_context_len = model_runner.model_config.context_len
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        self.forward_metadata: Optional[VllmMLADecodeMetadata] = None

        self.skip_prefill = skip_prefill
        if not skip_prefill:
            # Originally used the triton backend for prefill (kept for reference);
            # may be revisited later.
            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
            # self.triton_backend = TritonAttnBackend(
            #     model_runner,
            #     skip_prefill=False,
            #     kv_indptr_buf=kv_indptr_buf,
            # )
            # Prefill now uses the flash attention backend instead.
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.flashattn_backend = FlashAttentionBackend(
                model_runner,
                skip_prefill=False,
            )

    def _build_decode_metadata(
        self, forward_batch: ForwardBatch, seq_lens: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        bs = forward_batch.batch_size
        # Pad every request's page table to the longest sequence in the batch;
        # the paging scheme follows the vLLM reference implementation.
        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
        block_kv_indices = torch.full(
            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
        )
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            forward_batch.req_pool_indices,
            seq_lens,
            None,
            block_kv_indices,
            self.req_to_token.stride(0),
            max_seqlen_pad,
        )
        mla_metadata, num_splits = get_mla_metadata(
            seq_lens.to(torch.int32), self.num_q_heads, 1
        )
        return (mla_metadata, num_splits), num_splits, block_kv_indices

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            # Decode runs through FlashMLA.
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits_t, block_kv_indices
            )
        elif forward_batch.forward_mode.is_target_verify():
            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                self._build_decode_metadata(forward_batch, seq_lens)
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits_t, block_kv_indices
            )
        else:
            # Prefill/extend: originally the triton backend, now flash attention.
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata(forward_batch)
                self.flashattn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
            cuda_graph_kv_indices = torch.full(
                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
                1,
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = block_kv_indices

        if self.num_draft_tokens:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(
                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
                ),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
        else:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(
                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
                ),
                self.num_q_heads,
                1,
            )
        self.cuda_graph_mla_metadata = mla_metadata
        self.cuda_graph_num_splits = num_splits
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
    ):
        if forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        else:
            if not self.skip_prefill:
                # Previously delegated to the triton backend:
                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
                #     bs,
                #     num_tokens,
                #     req_pool_indices,
                #     seq_lens,
                #     encoder_lens,
                #     forward_mode,
                #     spec_info,
                # )
                self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                    bs,
                    num_tokens,
                    req_pool_indices,
                    seq_lens,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            seq_lens = seq_lens[:bs]
            seq_lens_cpu = seq_lens_cpu[:bs]
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        elif forward_mode.is_target_verify():
            assert seq_lens_cpu is not None
            seq_lens = seq_lens[:bs] + self.num_draft_tokens
            seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        else:
            if not self.skip_prefill:
                # Previously delegated to the triton backend:
                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
                #     bs,
                #     req_pool_indices,
                #     seq_lens,
                #     seq_lens_sum,
                #     encoder_lens,
                #     forward_mode,
                #     spec_info,
                #     seq_lens_cpu,
                # )
                self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                    bs,
                    req_pool_indices,
                    seq_lens,
                    seq_lens_sum,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                    seq_lens_cpu,
                )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def _call_decode(
        self,
        reshape_q: torch.Tensor,
        k_cache_reshaped: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        scaling: float,
    ):
        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
        )
        return o

    def _call_fp8_decode(
        self,
        reshape_q: torch.Tensor,
        k_cache_reshaped: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        scaling: float,
    ):
        assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
        o, _ = flash_mla_with_kvcache_quantization(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
            is_fp8_kvcache=True,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc
        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )

        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        if self.data_type in (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e4m3fnuz", None),
            getattr(torch, "float8_e5m2", None),
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            o = self._call_fp8_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                forward_batch.seq_lens.to(torch.int32),
                layer.scaling,
            )
        else:
            o = self._call_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                forward_batch.seq_lens.to(torch.int32),
                layer.scaling,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        sinks=None,
    ):
        if save_kv_cache:
            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)

        if (
            forward_batch.forward_mode == ForwardMode.EXTEND
            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
        ):
            # flash_attn does not support fp8, so extend cannot run with an fp8 KV cache.
            if not self.skip_prefill:
                # return self.triton_backend.forward_extend(
                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
                # )
                return self.flashattn_backend.forward_extend(
                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
                )
            else:
                raise RuntimeError("forward_extend called with skip_prefill=True")

        cache_loc = forward_batch.out_cache_loc
        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        if self.data_type in (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e4m3fnuz", None),
            getattr(torch, "float8_e5m2", None),
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            o = self._call_fp8_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
            )
        else:
            o = self._call_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
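

# Illustrative sketch only (not used by the backend): _build_decode_metadata pads
# every request's page table to ceil(max_seq_len / PAGE_SIZE) pages and marks
# unused slots with -1 before handing the table to FlashMLA. The pure-torch loop
# below mirrors that layout on CPU with toy page ids; in the backend the table is
# filled on the GPU by create_flashmla_kv_indices_triton from req_to_token and
# req_pool_indices.
if __name__ == "__main__":
    seq_lens = torch.tensor([70, 200, 64], dtype=torch.int32)
    bs = seq_lens.numel()
    # Same value triton.cdiv(seq_lens.max().item(), PAGE_SIZE) produces.
    max_seqlen_pad = (int(seq_lens.max()) + PAGE_SIZE - 1) // PAGE_SIZE
    block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
    for i, seq_len in enumerate(seq_lens.tolist()):
        num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
        # Toy, made-up page ids purely for illustration.
        block_kv_indices[i, :num_pages] = torch.arange(
            i * max_seqlen_pad, i * max_seqlen_pad + num_pages, dtype=torch.int32
        )
    # Shape (3, 4); each row is padded with -1 past that request's last page.
    print(block_kv_indices)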