Commit 3246cea1 authored by lizhigong's avatar lizhigong

Merge branch 'v0.5.4_dev_linhai' into 'v0.5.4_dev'

V0.5.4 dev linhai

See merge request OpenDAS/sglang!17
parents 93eb92f8 59b01a00
@@ -2,7 +2,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

 import torch
 import triton
@@ -47,6 +47,16 @@ class VllmMLADecodeMetadata:
     num_splits: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None

+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices
+

 class DCUMLABackend(AttentionBackend):
     def __init__(
@@ -86,14 +96,6 @@ class DCUMLABackend(AttentionBackend):
         self.skip_prefill = skip_prefill

         if not skip_prefill:
-            # Use the Triton backend for now; consider replacing it later
-            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-            # self.triton_backend = TritonAttnBackend(
-            #     model_runner,
-            #     skip_prefill=False,
-            #     kv_indptr_buf=kv_indptr_buf,
-            # )
-            # Switch prefill to flash attn
             from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
             self.flashattn_backend = FlashAttentionBackend(
                 model_runner,
@@ -147,9 +149,7 @@ class DCUMLABackend(AttentionBackend):
                 mla_metadata, num_splits_t, block_kv_indices
             )
         else:
-            # Used the Triton backend for prefill/extend -> switched to flash attn
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata(forward_batch)
                 self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
@@ -241,15 +241,6 @@ class DCUMLABackend(AttentionBackend):
             )
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
-                #     bs,
-                #     num_tokens,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                # )
                 self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                     bs,
                     num_tokens,
@@ -321,16 +312,6 @@ class DCUMLABackend(AttentionBackend):
             ]
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
-                #     bs,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     seq_lens_sum,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                #     seq_lens_cpu,
-                # )
                 self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                     bs,
                     req_pool_indices,
...@@ -363,7 +344,7 @@ class DCUMLABackend(AttentionBackend): ...@@ -363,7 +344,7 @@ class DCUMLABackend(AttentionBackend):
def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor, def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
block_table: torch.Tensor, cache_seqlens: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor,
scaling: float): scaling: float, k_scale=None, kv_cache_dtype=None):
assert _has_flash_mla, "FP8 KV cache 需要flash_mla包" assert _has_flash_mla, "FP8 KV cache 需要flash_mla包"
o, _ = flash_mla_with_kvcache_quantization( o, _ = flash_mla_with_kvcache_quantization(
q=reshape_q, q=reshape_q,
...@@ -375,7 +356,8 @@ class DCUMLABackend(AttentionBackend): ...@@ -375,7 +356,8 @@ class DCUMLABackend(AttentionBackend):
num_splits=self.forward_metadata.num_splits, num_splits=self.forward_metadata.num_splits,
softmax_scale=scaling, softmax_scale=scaling,
causal=True, causal=True,
is_fp8_kvcache=True, k_scale=k_scale,
kv_cache_dtype=kv_cache_dtype,
) )
return o return o
@@ -412,14 +394,29 @@ class DCUMLABackend(AttentionBackend):
             getattr(torch, "float8_e5m2", None),
             getattr(torch, "float8_e5m2fnuz", None),
         ):
+            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
+                    k_cache_reshaped.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
+            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
+                    k_cache_reshaped.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
             o = self._call_fp8_decode(
-                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32), layer.scaling,
+                reshape_q,
+                k_cache_reshaped,
+                self.forward_metadata.block_kv_indices[:bs],
+                forward_batch.seq_lens.to(torch.int32),
+                layer.scaling,
+                k_scale.to(torch.float32),
+                kv_cache_dtype=kv_cache_dtype,
             )
         else:
             o = self._call_decode(
-                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32), layer.scaling,
+                reshape_q,
+                k_cache_reshaped,
+                self.forward_metadata.block_kv_indices[:bs],
+                forward_batch.seq_lens.to(torch.int32),
+                layer.scaling,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -432,7 +429,6 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
@@ -445,11 +441,7 @@ class DCUMLABackend(AttentionBackend):
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
         ):
-            # flash_attn does not support fp8, so extend cannot run correctly with fp8
             if not self.skip_prefill:
-                # return self.triton_backend.forward_extend(
-                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
-                # )
                 return self.flashattn_backend.forward_extend(
                     q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
                 )
@@ -474,14 +466,27 @@ class DCUMLABackend(AttentionBackend):
             getattr(torch, "float8_e5m2", None),
             getattr(torch, "float8_e5m2fnuz", None),
         ):
+            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
+                    k_cache_reshaped.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
+            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
+                    k_cache_reshaped.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
             o = self._call_fp8_decode(
-                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
+                reshape_q,
+                k_cache_reshaped,
+                self.forward_metadata.block_kv_indices[:bs],
                 (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                 layer.scaling,
+                k_scale.to(torch.float32),
+                kv_cache_dtype=kv_cache_dtype,
             )
         else:
             o = self._call_decode(
-                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
+                reshape_q,
+                k_cache_reshaped,
+                self.forward_metadata.block_kv_indices[:bs],
                 (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )
@@ -489,3 +494,95 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+
+class DCUMLAMultiStepDraftBackend:
+    """
+    Wrap multiple flashmla attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        if topk > 1:
+            raise ValueError(
+                "Currently FlashMLA only supports topk=1 for speculative decoding"
+            )
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends.append(
+                DCUMLABackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=None,
+                )
+            )
+
+    def common_template(
+        self,
+        forward_batch: ForwardBatch,
+        call_fn: Callable,
+    ):
+        assert forward_batch.spec_info is not None
+
+        for i in range(self.speculative_num_steps - 1):
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
+
+        self.common_template(forward_batch, call_fn)
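Note on the new `DCUMLAMultiStepDraftBackend` above: it follows the same wrapper pattern as the existing `FlashMLAMultiStepDraftBackend`, allocating one per-step backend (`speculative_num_steps - 1` of them, with `topk` fixed to 1) and fanning every metadata call out to each of them through `common_template`. A minimal, self-contained sketch of that delegation pattern follows; `MiniBackend` / `MiniMultiStepWrapper` are illustrative stand-ins, not sglang APIs.

```python
# Sketch of the multi-step draft wrapper pattern used by DCUMLAMultiStepDraftBackend.
# MiniBackend / MiniMultiStepWrapper are hypothetical names, not sglang classes.
from typing import Callable, Dict, List


class MiniBackend:
    """Stands in for one per-step attention backend (e.g. DCUMLABackend)."""

    def __init__(self, step: int):
        self.step = step

    def init_forward_metadata(self, batch: Dict[str, int]) -> None:
        print(f"step {self.step}: metadata for batch of size {batch['bs']}")


class MiniMultiStepWrapper:
    """Wraps one backend per draft step and fans each call out to all of them."""

    def __init__(self, speculative_num_steps: int, topk: int = 1):
        if topk > 1:
            raise ValueError("only topk=1 is supported in this sketch")
        # One backend per draft step except the last, mirroring the diff above.
        self.backends: List[MiniBackend] = [
            MiniBackend(i) for i in range(speculative_num_steps - 1)
        ]

    def common_template(
        self, batch: Dict[str, int], call_fn: Callable[[int, Dict[str, int]], None]
    ) -> None:
        # The real backend also asserts spec_info is present before fanning out.
        for i in range(len(self.backends)):
            call_fn(i, batch)

    def init_forward_metadata(self, batch: Dict[str, int]) -> None:
        self.common_template(
            batch, lambda i, b: self.backends[i].init_forward_metadata(b)
        )


if __name__ == "__main__":
    wrapper = MiniMultiStepWrapper(speculative_num_steps=4)
    wrapper.init_forward_metadata({"bs": 8})  # fans out to 3 per-step backends
```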
@@ -695,6 +695,7 @@ class FlashAttentionBackend(AttentionBackend):
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
         # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        data_dtype = q.dtype
         if (
             self.kv_cache_dtype_str != "auto"
             and layer.head_dim <= 256
@@ -828,7 +829,9 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.attn_attend_prefix_cache is not None
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
+                k_descale = k_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
+                v_descale = v_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
                 # Do multi-head attention with chunked prefix cache
                 if forward_batch.attn_attend_prefix_cache:
                     assert not get_global_server_args().disable_chunked_prefix_cache
@@ -842,9 +845,9 @@ class FlashAttentionBackend(AttentionBackend):
                     assert forward_batch.mha_return_lse
                     output = flash_attn_varlen_func(
-                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
-                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
+                        k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
+                        v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
                         cu_seqlens_q=metadata.cu_seqlens_q,
                         cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                         max_seqlen_q=metadata.max_seq_len_q,
@@ -855,11 +858,10 @@ class FlashAttentionBackend(AttentionBackend):
                         **kwargs,
                     )
                 else:
-                    # MHA for extend part of sequence without attending prefix kv cache
                     output = flash_attn_varlen_func(
-                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
-                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
+                        k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
+                        v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
                         cu_seqlens_q=metadata.cu_seqlens_q,
                         cu_seqlens_k=metadata.cu_seqlens_q,
                         max_seqlen_q=metadata.max_seq_len_q,
...
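Note on the fp8 KV-cache handling added above: both backends map the torch float8 cache dtype to a kernel dtype string ("fp8_e4m3" or "fp8_e5m2") and fall back to a descale factor of 1.0 when the layer carries no calibrated `k_scale`. A small standalone sketch of that mapping follows; the helper names are illustrative, not part of sglang.

```python
# Standalone sketch of the fp8 KV-cache dtype mapping and descale fallback
# used in the diff above. Helper names are hypothetical, not sglang APIs.
from typing import Optional

import torch


def fp8_kv_cache_dtype(kv_dtype: torch.dtype) -> Optional[str]:
    """Map a torch float8 dtype to the kernel's kv_cache_dtype string, else None."""
    e4m3 = {getattr(torch, "float8_e4m3fn", None), getattr(torch, "float8_e4m3fnuz", None)}
    e5m2 = {getattr(torch, "float8_e5m2", None), getattr(torch, "float8_e5m2fnuz", None)}
    if kv_dtype in e4m3 - {None}:
        return "fp8_e4m3"
    if kv_dtype in e5m2 - {None}:
        return "fp8_e5m2"
    return None


def descale_or_default(scale: Optional[torch.Tensor], device: torch.device) -> torch.Tensor:
    """Use the layer's calibrated scale if present, otherwise a neutral 1.0."""
    if scale is not None:
        return scale.to(torch.float32)
    return torch.tensor([1.0], dtype=torch.float32, device=device)


if __name__ == "__main__":
    if hasattr(torch, "float8_e4m3fn"):
        print(fp8_kv_cache_dtype(torch.float8_e4m3fn))  # -> "fp8_e4m3"
    print(descale_or_default(None, torch.device("cpu")))  # -> tensor([1.])
```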
@@ -2296,7 +2296,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 # Fetch latent cache from memory pool with precomputed chunked kv indices
                 latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                     self.attn_mha.layer_id
-                )
+                ).to(q.dtype)
                 latent_cache = (
                     latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
                     .contiguous()
...
@@ -46,6 +46,7 @@ class DraftBackendFactory:
                 else self._create_triton_decode_backend
             ),
             "flashmla": self._create_flashmla_decode_backend,
+            "dcu_mla": self._create_dcumla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
             "nsa": self._create_nsa_decode_backend,
@@ -69,6 +70,7 @@ class DraftBackendFactory:
                 else self._create_triton_prefill_backend
             ),
             "flashmla": self._create_flashmla_prefill_backend,
+            "dcu_mla": self._create_dcumla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
             "nsa": self._create_nsa_prefill_backend,
@@ -149,6 +151,15 @@ class DraftBackendFactory:
         return FlashMLAMultiStepDraftBackend(
             self.draft_model_runner, self.topk, self.speculative_num_steps
         )

+    def _create_dcumla_decode_backend(self):
+        from sglang.srt.layers.attention.dcu_mla_backend import (
+            DCUMLAMultiStepDraftBackend,
+        )
+
+        return DCUMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
     def _create_trtllm_mha_decode_backend(self):
         from sglang.srt.layers.attention.trtllm_mha_backend import (
@@ -224,3 +235,9 @@ class DraftBackendFactory:
             "flashmla prefill backend is not yet supported for draft extend."
         )
         return None
+
+    def _create_dcumla_prefill_backend(self):
+        logger.warning(
+            "dcu_mla prefill backend is not yet supported for draft extend."
+        )
+        return None