Commit d629db06 authored by linhai1

add draft_extend support for dcu_mla.

parent 4d106b5f
@@ -86,14 +86,6 @@ class DCUMLABackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         if not skip_prefill:
-            # Use the triton backend for now; consider replacing it later
-            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-            # self.triton_backend = TritonAttnBackend(
-            #     model_runner,
-            #     skip_prefill=False,
-            #     kv_indptr_buf=kv_indptr_buf,
-            # )
-            # Switch prefill to flash attn
             from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
             self.flashattn_backend = FlashAttentionBackend(
                 model_runner,
@@ -109,7 +101,6 @@ class DCUMLABackend(AttentionBackend):
         bs = forward_batch.batch_size
         max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
-        # Paging scheme follows the vLLM official blog
         block_kv_indices = torch.full(
             (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
         )
@@ -131,7 +122,6 @@ class DCUMLABackend(AttentionBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
-            # Use flashmla for decode
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                 self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
             )
@@ -147,9 +137,7 @@ class DCUMLABackend(AttentionBackend):
                 mla_metadata, num_splits_t, block_kv_indices
             )
         else:
-            # prefill/extend used the triton backend -> switched to flash attn
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata(forward_batch)
                 self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
@@ -241,15 +229,6 @@ class DCUMLABackend(AttentionBackend):
             )
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
-                #     bs,
-                #     num_tokens,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                # )
                 self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                     bs,
                     num_tokens,
@@ -321,16 +300,6 @@ class DCUMLABackend(AttentionBackend):
             ]
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
-                #     bs,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     seq_lens_sum,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                #     seq_lens_cpu,
-                # )
                 self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                     bs,
                     req_pool_indices,
@@ -387,30 +356,18 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
-        q_rope: Optional[torch.Tensor] = None,
-        k_rope: Optional[torch.Tensor] = None,
-        sinks: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                if k_rope is None:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-                else:
-                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        k_rope,
-                    )
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -444,22 +401,14 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
-        q_rope = None,
-        k_rope = None,
-        sinks = None,
     ):
         if (
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
         ):
-            # flash_attn does not support fp8, so extend cannot run correctly with fp8
             if not self.skip_prefill:
-                # return self.triton_backend.forward_extend(
-                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
-                # )
                 return self.flashattn_backend.forward_extend(
-                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
+                    q, k, v, layer, forward_batch, save_kv_cache,
                 )
             else:
                 raise RuntimeError("skip prefill but use forward_extend")
@@ -468,21 +417,12 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                # forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
-                if k_rope is None:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-                else:
-                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        k_rope,
-                    )
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
...
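For context, the extend path in the hunks above now treats speculative draft-extend batches the same as regular prefill and hands both to the flash-attention backend, while decode stays on flashmla. The snippet below is a minimal standalone sketch of that dispatch; the enum and the callable are stand-ins, not the real ForwardMode or FlashAttentionBackend.

# Standalone sketch (illustrative, not the real backend class).
from enum import Enum, auto

class Mode(Enum):
    EXTEND = auto()
    DRAFT_EXTEND = auto()
    DECODE = auto()

def dispatch_extend(mode, skip_prefill, prefill_backend):
    # Regular prefill and speculative draft-extend both use the
    # flash-attention prefill backend; decode is handled elsewhere (flashmla).
    if mode in (Mode.EXTEND, Mode.DRAFT_EXTEND):
        if skip_prefill:
            raise RuntimeError("skip prefill but use forward_extend")
        return prefill_backend(mode)
    raise NotImplementedError("decode goes through the flashmla path")

print(dispatch_extend(Mode.DRAFT_EXTEND, False, lambda m: f"flash-attn extend for {m.name}"))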
@@ -668,9 +668,11 @@ class FlashAttentionBackend(AttentionBackend):
                     if not layer.is_cross_attention
                     else forward_batch.encoder_out_cache_loc
                 )
-                if not self.use_mla:
+                # if not self.use_mla:
+                if k_rope is None:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        # layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        layer, cache_loc, k, v
                     )
                 else:
                     forward_batch.token_to_kv_pool.set_mla_kv_buffer(
...
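The flash-attention hunk above keys the KV-cache write on whether a separate rotary key slice (k_rope) is present, instead of on use_mla. A minimal sketch of that dispatch follows; _Pool is a stand-in for token_to_kv_pool, mimicking only the two setter methods used above.

# Sketch only: choose between the regular KV buffer and the MLA buffer.
class _Pool:
    def set_kv_buffer(self, layer, loc, k, v):
        return "regular KV write (k and v stored together)"

    def set_mla_kv_buffer(self, layer, loc, k, k_rope):
        return "MLA write (latent k and rope slice stored separately)"

def write_kv(pool, layer, cache_loc, k, v, k_rope=None):
    if k_rope is None:
        return pool.set_kv_buffer(layer, cache_loc, k, v)
    return pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope)

print(write_kv(_Pool(), layer=None, cache_loc=None, k=None, v=None))
print(write_kv(_Pool(), layer=None, cache_loc=None, k=None, v=None, k_rope=object()))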
@@ -1662,9 +1662,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions,
         topk_indices,
     ):
-        # if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
-        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or \
-                (not forward_batch.forward_mode.is_decode() and self.current_attention_backend == 'dcu_mla'):
+        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
                 extra_args = {
...
@@ -27,7 +27,10 @@ class DraftBackendFactory:
         backend_type = self.server_args.attention_backend
         if backend_type not in backend_map:
-            raise ValueError(error_template.format(backend_type=backend_type))
+            if backend_type != "dcu_mla":
+                raise ValueError(error_template.format(backend_type=backend_type))
+            else:
+                return backend_map["fa3"]()
         return backend_map[backend_type]()
...
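The DraftBackendFactory hunk above is what enables draft extend for dcu_mla: the backend has no dedicated draft attention implementation, so the factory reuses the fa3 (flash-attention) draft backend instead of raising. A standalone sketch of that lookup, with a toy registry standing in for the real backend_map:

# Illustrative only: backend_map and error_template here are toy stand-ins.
backend_map = {"fa3": lambda: "fa3 draft backend", "triton": lambda: "triton draft backend"}
error_template = "Unsupported draft attention backend: {backend_type}"

def resolve_draft_backend(backend_type):
    if backend_type not in backend_map:
        if backend_type != "dcu_mla":
            raise ValueError(error_template.format(backend_type=backend_type))
        # dcu_mla falls back to the fa3 draft backend rather than failing.
        return backend_map["fa3"]()
    return backend_map[backend_type]()

print(resolve_draft_backend("dcu_mla"))  # -> "fa3 draft backend"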