Commit cd0b5891 authored by linhai1's avatar linhai1

Refer to FlashMLA to add a decode backend.

parent d629db06
@@ -2,7 +2,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import torch
import triton
@@ -47,6 +47,16 @@ class VllmMLADecodeMetadata:
    num_splits: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None

    def __init__(
        self,
        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        num_splits: Optional[torch.Tensor] = None,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        self.flashmla_metadata = flashmla_metadata
        self.num_splits = num_splits
        self.block_kv_indices = block_kv_indices


class DCUMLABackend(AttentionBackend):
    def __init__(
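A quick note on the `VllmMLADecodeMetadata` hunk above: the explicit `__init__` mirrors the `@dataclass` fields, so the container can be built positionally (as `init_forward_metadata` does in the next hunk) or created empty and filled in later. A minimal sketch, assuming the class is importable from the backend module referenced later in this commit (`sglang.srt.layers.attention.dcu_mla_backend`); the tensor values are placeholders, not what the kernels actually produce:

```python
import torch

from sglang.srt.layers.attention.dcu_mla_backend import VllmMLADecodeMetadata

# Placeholder stand-ins for the outputs of get_mla_metadata and
# create_flashmla_kv_indices_triton; shapes are illustrative only.
mla_metadata = (torch.empty(0, dtype=torch.int32), torch.empty(0, dtype=torch.int32))
num_splits = torch.empty(0, dtype=torch.int32)
block_kv_indices = torch.full((1, 4), -1, dtype=torch.int32)

# Positional construction, matching the calls in init_forward_metadata.
meta = VllmMLADecodeMetadata(mla_metadata, num_splits, block_kv_indices)

# Or start empty (all fields default to None) and assign afterwards.
meta = VllmMLADecodeMetadata()
meta.block_kv_indices = block_kv_indices
```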
@@ -92,49 +102,70 @@ class DCUMLABackend(AttentionBackend):
            skip_prefill=False,
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(
                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
            )
            block_kv_indices = torch.full(
                (bs, max_seqlen_pad),
                -1,
                dtype=torch.int32,
                device=forward_batch.seq_lens.device,
            )
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                None,
                block_kv_indices,
                self.req_to_token.stride(0),
                max_seqlen_pad,
            )
            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
                self.num_q_heads,
                1,
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata,
                num_splits,
                block_kv_indices,
            )
        elif forward_batch.forward_mode.is_target_verify():
            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
            seq_lens = forward_batch.seq_lens + self.num_draft_tokens

            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            block_kv_indices = torch.full(
                (bs, max_seqlen_pad),
                -1,
                dtype=torch.int32,
                device=seq_lens.device,
            )
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                seq_lens,
                None,
                block_kv_indices,
                self.req_to_token.stride(0),
                max_seqlen_pad,
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata,
                num_splits,
                block_kv_indices,
            )
        else:
            if not self.skip_prefill:
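For reference, both branches above size the block KV index table the same way: one row per request, `ceil(max_seq_len / PAGE_SIZE)` page slots per row, initialized to -1 so unused slots are easy to spot. A minimal standalone sketch of that shape calculation, assuming an illustrative `PAGE_SIZE` of 64 (the real constant comes from the backend module):

```python
import math
from typing import Tuple

PAGE_SIZE = 64  # assumed value for illustration; the backend imports its own constant


def padded_block_table_shape(batch_size: int, max_seq_len: int) -> Tuple[int, int]:
    # Mirrors triton.cdiv(seq_lens.max().item(), PAGE_SIZE) in the hunk above.
    max_seqlen_pad = math.ceil(max_seq_len / PAGE_SIZE)
    return batch_size, max_seqlen_pad


# Example: 3 requests, longest sequence 200 tokens -> a (3, 4) table; entries past
# each request's own length stay at -1 after create_flashmla_kv_indices_triton runs.
print(padded_block_table_shape(3, 200))  # (3, 4)
```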
@@ -450,4 +481,95 @@ class DCUMLABackend(AttentionBackend):
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

class DCUMLAMultiStepDraftBackend:
    """
    Wrap multiple DCU MLA attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        if topk > 1:
            raise ValueError(
                "Currently FlashMLA only supports topk=1 for speculative decoding"
            )
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )

        self.attn_backends = []
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends.append(
                DCUMLABackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                    kv_last_page_len_buf=None,
                )
            )

    def common_template(
        self,
        forward_batch: ForwardBatch,
        call_fn: Callable,
    ):
        assert forward_batch.spec_info is not None

        for i in range(self.speculative_num_steps - 1):
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            assert forward_batch.spec_info is not None
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, block_kv_indices=None
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )

        self.common_template(forward_batch, call_fn)
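The wrapper above is a thin dispatcher: it holds one `DCUMLABackend` per draft step (the last step is handled outside the wrapper, hence `speculative_num_steps - 1`), and `common_template` simply replays a callback across them. A stripped-down sketch of that pattern with stand-in classes (the names here are illustrative, not the real sglang types):

```python
from typing import Callable, List


class _StepBackend:
    """Stand-in for one per-step attention backend (illustrative only)."""

    def init_forward_metadata(self, forward_batch) -> None:
        pass  # the real backend builds FlashMLA-style decode metadata here


class _MultiStepWrapper:
    """Mirrors the common_template dispatch of DCUMLAMultiStepDraftBackend."""

    def __init__(self, speculative_num_steps: int):
        self.attn_backends: List[_StepBackend] = [
            _StepBackend() for _ in range(speculative_num_steps - 1)
        ]

    def common_template(self, forward_batch, call_fn: Callable) -> None:
        for i in range(len(self.attn_backends)):
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch) -> None:
        self.common_template(
            forward_batch,
            lambda i, fb: self.attn_backends[i].init_forward_metadata(fb),
        )


_MultiStepWrapper(speculative_num_steps=3).init_forward_metadata(forward_batch=None)
```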
@@ -871,9 +871,16 @@ class FlashAttentionBackend(AttentionBackend):
                return_softmax_lse=forward_batch.mha_return_lse,
                **kwargs,
            )
            # if layer.layer_id == 0:
            #     print('### mha output, q, k, v', output.shape, q.shape, k.shape, v.shape)
            # torch.Size([136, 16, 128]) torch.Size([136, 16, 192]) torch.Size([136, 16, 192]) torch.Size([136, 16, 128])
            # torch.Size([7, 16, 128]) torch.Size([7, 16, 192]) torch.Size([7, 16, 192]) torch.Size([7, 16, 128])
            # torch.Size([40, 16, 128]) torch.Size([40, 16, 192]) torch.Size([40, 16, 192]) torch.Size([40, 16, 128])
            if forward_batch.mha_return_lse:
                output, lse, *rest = output
                lse = torch.transpose(lse, 0, 1).contiguous()
                # if layer.layer_id == 0:
                #     print('###output, lse, q, k, v', output.shape, lse.shape, q.shape, k.shape, v.shape)
                return output, lse
            return output
        else:
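On the `lse` handling above: the transpose suggests the kernel hands back the log-sum-exp head-major, and flipping the first two dimensions makes it token-major so it lines up with the `[tokens, heads, head_dim]` output recorded in the debug comments. A tiny self-contained illustration of just that reshaping (shapes borrowed from those comments):

```python
import torch

# 16 heads, 136 tokens, matching the commented shape dumps above.
lse = torch.randn(16, 136)                     # head-major layout
lse = torch.transpose(lse, 0, 1).contiguous()  # token-major layout
assert lse.shape == (136, 16)                  # aligns with an output of shape (136, 16, 128)
```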
@@ -921,6 +928,10 @@ class FlashAttentionBackend(AttentionBackend):
                return_softmax_lse=use_cascade_attn,
                num_splits=self.num_splits,
            )
            # if layer.layer_id == 0:
            #     print('### mla output, q, k, v', result.shape, q_rope.shape, k_rope_cache.shape, c_kv_cache.shape)
            # torch.Size([8, 16, 512]) torch.Size([8, 16, 64]) torch.Size([3318, 64, 1, 64]) torch.Size([3318, 64, 1, 512])
            # torch.Size([286, 16, 512]) torch.Size([286, 16, 64]) torch.Size([3322, 64, 1, 64]) torch.Size([3322, 64, 1, 512])
            if use_cascade_attn:
                o, softmax_lse, *rest = result
                o_expand, softmax_lse_expand, *rest_expand = (
...
@@ -27,10 +27,7 @@ class DraftBackendFactory:
        backend_type = self.server_args.attention_backend
        if backend_type not in backend_map:
            raise ValueError(error_template.format(backend_type=backend_type))

        return backend_map[backend_type]()
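With the `dcu_mla` special case removed, draft backend selection above is plain dictionary dispatch: look the name up in `backend_map`, fail fast if it is missing. A minimal sketch of the same pattern with placeholder factories (not the real sglang creators):

```python
from typing import Callable, Dict


def _make_flashmla() -> str:  # placeholder factory, illustrative only
    return "flashmla-backend"


def _make_dcu_mla() -> str:  # placeholder factory, illustrative only
    return "dcu-mla-backend"


backend_map: Dict[str, Callable[[], str]] = {
    "flashmla": _make_flashmla,
    "dcu_mla": _make_dcu_mla,
}


def create_backend(backend_type: str) -> str:
    if backend_type not in backend_map:
        # Same failure mode as the factory above: unknown names raise immediately.
        raise ValueError(f"Unsupported attention backend: {backend_type}")
    return backend_map[backend_type]()


print(create_backend("dcu_mla"))  # dcu-mla-backend
```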
@@ -49,6 +46,7 @@ class DraftBackendFactory:
                else self._create_triton_decode_backend
            ),
            "flashmla": self._create_flashmla_decode_backend,
            "dcu_mla": self._create_dcumla_decode_backend,
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
            "nsa": self._create_nsa_decode_backend,
@@ -72,6 +70,7 @@ class DraftBackendFactory:
                else self._create_triton_prefill_backend
            ),
            "flashmla": self._create_flashmla_prefill_backend,
            "dcu_mla": self._create_dcumla_prefill_backend,
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
            "nsa": self._create_nsa_prefill_backend,
@@ -152,6 +151,15 @@ class DraftBackendFactory:
        return FlashMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_dcumla_decode_backend(self):
        from sglang.srt.layers.attention.dcu_mla_backend import (
            DCUMLAMultiStepDraftBackend,
        )

        return DCUMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mha_decode_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import (
@@ -227,3 +235,9 @@ class DraftBackendFactory:
            "flashmla prefill backend is not yet supported for draft extend."
        )
        return None

    def _create_dcumla_prefill_backend(self):
        logger.warning(
            "dcu_mla prefill backend is not yet supported for draft extend."
        )
        return None
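`_create_dcumla_prefill_backend` mirrors the flashmla path: log a warning and return `None`, presumably so the caller can tell that no draft-extend prefill backend was built. A minimal sketch of that warn-and-return-None convention (the helper name here is hypothetical, not part of the factory):

```python
import logging

logger = logging.getLogger(__name__)


def _create_unsupported_prefill_backend(name: str) -> None:
    # Hypothetical helper illustrating the convention used by the flashmla and
    # dcu_mla creators above: warn once, return None instead of a backend.
    logger.warning("%s prefill backend is not yet supported for draft extend.", name)
    return None


assert _create_unsupported_prefill_backend("dcu_mla") is None
```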