Commit 27815038 authored by niuhb

mtp

parent de61a992
@@ -151,8 +151,59 @@ class DCUMLABackend(AttentionBackend):
                 )
         else:
             if not self.skip_prefill:
-                self.flashattn_backend.init_forward_metadata(forward_batch)
+                # === DRAFT_EXTEND_V2 MLA metadata === nhb
+                if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
+                    bs = forward_batch.batch_size
+                    seq_lens_cpu = forward_batch.seq_lens_cpu
+                    seq_lens = forward_batch.seq_lens
+                    max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+                    block_kv_indices = torch.full(
+                        (bs, max_seqlen_pad),
+                        -1,
+                        dtype=torch.int32,
+                        device=seq_lens.device,
+                    )
+                    # Call the Triton kernel to generate block_kv_indices
+                    if use_sglang_create_flashmla_kv_indices_triton:
+                        dcu_create_flashmla_kv_indices(
+                            req_to_token_ptr=self.req_to_token.to(torch.int32),
+                            req_pool_indices_ptr=forward_batch.req_pool_indices.to(torch.int32),
+                            page_kernel_lens_ptr=forward_batch.seq_lens.to(torch.int32),
+                            kv_start_idx=None,
+                            kv_indices_ptr=block_kv_indices.to(torch.int32),
+                            req_to_token_ptr_stride=self.req_to_token.stride(0),
+                            kv_indices_ptr_stride=max_seqlen_pad,
+                        )
+                    else:
+                        create_flashmla_kv_indices_triton[(bs,)](
+                            self.req_to_token,
+                            forward_batch.req_pool_indices,
+                            forward_batch.seq_lens,
+                            None,
+                            block_kv_indices,
+                            self.req_to_token.stride(0),
+                            max_seqlen_pad,
+                        )
+                    # MLA scheduling metadata
+                    mla_metadata, num_splits = get_mla_metadata(
+                        seq_lens.to(torch.int32),
+                        self.num_q_heads,
+                        1,
+                    )
+                    # Save forward_metadata
+                    self.forward_metadata = VllmMLADecodeMetadata(
+                        mla_metadata,
+                        num_splits,
+                        block_kv_indices,
+                    )
+                self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
         self,
         max_bs: int,
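For readers following the new DRAFT_EXTEND_V2 branch above: the kernels fill block_kv_indices, a per-request table of KV-cache page indices derived from req_to_token. The sketch below is a rough pure-PyTorch reference of what that table is assumed to contain; the PAGE_SIZE value, the page-aligned slot layout, and the helper name are illustrative assumptions, not sglang's actual kernel.

import torch

PAGE_SIZE = 64  # assumed page size; the real value comes from the backend's config

def build_block_kv_indices_reference(
    req_to_token: torch.Tensor,      # [max_reqs, max_context_len] token-slot table
    req_pool_indices: torch.Tensor,  # [bs] row of req_to_token used by each request
    seq_lens: torch.Tensor,          # [bs] current KV length per request
    max_seqlen_pad: int,             # ceil(max_seq_len / PAGE_SIZE)
) -> torch.Tensor:
    bs = req_pool_indices.numel()
    block_kv_indices = torch.full(
        (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
    )
    for i in range(bs):
        row = req_to_token[req_pool_indices[i]]
        num_pages = (int(seq_lens[i]) + PAGE_SIZE - 1) // PAGE_SIZE
        # Assuming token slots are page-aligned, the block index of each page is the
        # slot of its first token divided by PAGE_SIZE.
        first_slots = row[: num_pages * PAGE_SIZE : PAGE_SIZE]
        block_kv_indices[i, :num_pages] = (first_slots // PAGE_SIZE).to(torch.int32)
    return block_kv_indices

With 64-token pages, a request whose seq_len is 130 occupies three pages, so its row holds three block indices followed by -1 padding, which matches the torch.full(-1) initialization in the diff.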
@@ -431,7 +482,7 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if save_kv_cache:
+        if save_kv_cache and self.num_draft_tokens == 0:  #nhb
            return self.forward_decode(q,k,v,layer,forward_batch, save_kv_cache)
        if ((
......
@@ -598,6 +598,7 @@ class FlashAttentionBackend(AttentionBackend):
         if (
             any(forward_batch.extend_prefix_lens_cpu)
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2  #nhb
         ):
             extend_seq_lens = forward_batch.extend_seq_lens
             metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
@@ -668,9 +669,9 @@ class FlashAttentionBackend(AttentionBackend):
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
-        if not self.use_mla:
+        if k_rope is None:
             forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                layer, cache_loc, k, v,  #layer.k_scale, layer.v_scale
             )
         else:
             forward_batch.token_to_kv_pool.set_mla_kv_buffer(
......
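The hunk above changes the KV-cache write in the extend path from a use_mla flag to checking whether a separate rope component was passed in. Below is a minimal sketch of that dispatch, assuming a toy pool class; only the two method names come from the diff, everything else (shapes, constructor, helper) is a hypothetical stand-in for sglang's actual TokenToKVPool API.

import torch

class ToyKVPool:
    def __init__(self, num_tokens, num_heads, head_dim, kv_lora_rank, rope_dim):
        # Fused path: per-token K and V of shape [num_heads, head_dim].
        self.k_buffer = torch.zeros(num_tokens, num_heads, head_dim)
        self.v_buffer = torch.zeros(num_tokens, num_heads, head_dim)
        # MLA path: compressed latent plus a separate rotary component.
        self.latent_buffer = torch.zeros(num_tokens, 1, kv_lora_rank)
        self.rope_buffer = torch.zeros(num_tokens, 1, rope_dim)

    def set_kv_buffer(self, cache_loc, k, v):
        self.k_buffer[cache_loc] = k
        self.v_buffer[cache_loc] = v

    def set_mla_kv_buffer(self, cache_loc, k_nope, k_rope):
        self.latent_buffer[cache_loc] = k_nope
        self.rope_buffer[cache_loc] = k_rope

def save_kv(pool, cache_loc, k, v, k_rope=None):
    # Mirrors the dispatch in the diff: no separate rope tensor -> fused KV write,
    # otherwise the split MLA write.
    if k_rope is None:
        pool.set_kv_buffer(cache_loc, k, v)
    else:
        pool.set_mla_kv_buffer(cache_loc, k, k_rope)

pool = ToyKVPool(num_tokens=16, num_heads=2, head_dim=4, kv_lora_rank=8, rope_dim=2)
loc = torch.tensor([3, 5])
save_kv(pool, loc, k=torch.randn(2, 2, 4), v=torch.randn(2, 2, 4))                # fused path
save_kv(pool, loc, k=torch.randn(2, 1, 8), v=None, k_rope=torch.randn(2, 1, 2))   # MLA path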
@@ -1940,7 +1940,7 @@ class Scheduler(
             batch.spec_info = batch_result.next_draft_input
             batch.spec_info.future_indices = future_indices
+            batch.sampling_info.is_all_greedy = True  #nhb
             # batch.spec_info = EagleDraftInput(
             #     future_indices=future_indices,
             #     verify_done=batch_result.next_draft_input.verify_done,
......
@@ -129,6 +129,7 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or self == ForwardMode.DRAFT_EXTEND_V2  #nhb
         )

     def is_cuda_graph(self):
......
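The ForwardMode change above adds DRAFT_EXTEND_V2 to what appears to be the extend-classification predicate, so extend-style metadata paths pick it up. A minimal self-contained sketch of that pattern follows; the enum members, their values, and the method name is_extend are illustrative assumptions, not sglang's full ForwardMode.

from enum import IntEnum, auto

class ForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    MIXED = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()

    def is_extend(self) -> bool:
        # Membership check mirroring the chained "or self == ..." comparisons
        # in the diff; DRAFT_EXTEND_V2 is now treated like the other extend modes.
        return self in (
            ForwardMode.EXTEND,
            ForwardMode.MIXED,
            ForwardMode.DRAFT_EXTEND,
            ForwardMode.DRAFT_EXTEND_V2,
        )

assert ForwardMode.DRAFT_EXTEND_V2.is_extend()
assert not ForwardMode.DECODE.is_extend()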
@@ -237,7 +237,14 @@ class DraftBackendFactory:
         return None

     def _create_dcumla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
+        # logger.warning(
+        #     "flashmla prefill backend is not yet supported for draft extend."
+        # )
+        # return None
+        #nhb
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)