Commit d629db06 authored by linhai1

add draft_extend support for dcu_mla.

parent 4d106b5f
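
For orientation, the sketch below distills the dispatch pattern this patch settles on for the DCU MLA backend: decode batches keep the FlashMLA metadata path, while EXTEND and DRAFT_EXTEND batches are handed to a FlashAttention prefill backend. It is a minimal, self-contained illustration; ForwardMode.DRAFT_EXTEND and forward_extend appear in the diff, but the class name HybridMLABackend and the stripped-down signatures are stand-ins, not the actual sglang API.

    from enum import Enum, auto

    class ForwardMode(Enum):
        # Minimal stand-in for sglang's ForwardMode; only the members used here.
        EXTEND = auto()
        DRAFT_EXTEND = auto()
        DECODE = auto()

    class HybridMLABackend:
        # Illustrative stand-in for DCUMLABackend: decode stays on the FlashMLA
        # path, while normal extend and speculative draft extend are delegated
        # to a FlashAttention prefill backend.
        def __init__(self, flashattn_backend, skip_prefill=False):
            self.flashattn_backend = flashattn_backend
            self.skip_prefill = skip_prefill

        def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
            if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
                if self.skip_prefill:
                    raise RuntimeError("skip prefill but use forward_extend")
                # Prefill work (including the draft extend of speculative decoding)
                # is forwarded to the FlashAttention backend.
                return self.flashattn_backend.forward_extend(
                    q, k, v, layer, forward_batch, save_kv_cache
                )
            raise RuntimeError("forward_extend called with a non-extend forward mode")
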
@@ -86,14 +86,6 @@ class DCUMLABackend(AttentionBackend):
self.skip_prefill = skip_prefill
if not skip_prefill:
# Use the triton backend for now; consider replacing it later
# from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
# self.triton_backend = TritonAttnBackend(
# model_runner,
# skip_prefill=False,
# kv_indptr_buf=kv_indptr_buf,
# )
# prefill now uses flash attn instead
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
self.flashattn_backend = FlashAttentionBackend(
model_runner,
@@ -109,7 +101,6 @@ class DCUMLABackend(AttentionBackend):
bs = forward_batch.batch_size
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
# Paging scheme follows the official vLLM blog
block_kv_indices = torch.full(
(bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
)
@@ -131,7 +122,6 @@ class DCUMLABackend(AttentionBackend):
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
# decode uses flashmla
(mla_metadata, num_splits), num_splits_t, block_kv_indices = (
self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
)
@@ -147,9 +137,7 @@ class DCUMLABackend(AttentionBackend):
mla_metadata, num_splits_t, block_kv_indices
)
else:
# prefill/extend used the triton backend -> now uses flash attn
if not self.skip_prefill:
# self.triton_backend.init_forward_metadata(forward_batch)
self.flashattn_backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(
@@ -241,15 +229,6 @@ class DCUMLABackend(AttentionBackend):
)
else:
if not self.skip_prefill:
# self.triton_backend.init_forward_metadata_capture_cuda_graph(
# bs,
# num_tokens,
# req_pool_indices,
# seq_lens,
# encoder_lens,
# forward_mode,
# spec_info,
# )
self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
@@ -321,16 +300,6 @@ class DCUMLABackend(AttentionBackend):
]
else:
if not self.skip_prefill:
# self.triton_backend.init_forward_metadata_replay_cuda_graph(
# bs,
# req_pool_indices,
# seq_lens,
# seq_lens_sum,
# encoder_lens,
# forward_mode,
# spec_info,
# seq_lens_cpu,
# )
self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
@@ -387,30 +356,18 @@ class DCUMLABackend(AttentionBackend):
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
):
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
if k_rope is None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -444,22 +401,14 @@ class DCUMLABackend(AttentionBackend):
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
q_rope = None,
k_rope = None,
sinks = None,
):
if (
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
):
# flash_attn does not support fp8, so extend cannot run correctly with fp8
if not self.skip_prefill:
# return self.triton_backend.forward_extend(
# q, k, v, layer, forward_batch, save_kv_cache, sinks
# )
return self.flashattn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
q, k, v, layer, forward_batch, save_kv_cache,
)
else:
raise RuntimeError("skip prefill but use forward_extend")
@@ -468,21 +417,12 @@ class DCUMLABackend(AttentionBackend):
if k is not None:
assert v is not None
if save_kv_cache:
# forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if k_rope is None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
......
@@ -668,9 +668,11 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
# if not self.use_mla:
if k_rope is None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
# layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
......
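
The same k_rope-based branch recurs throughout this patch: when a separate rotary key component is supplied, the token is written through the MLA-specific cache setter, otherwise through the regular KV setter. Below is a hedged sketch of that pattern; write_kv_cache is a hypothetical helper name and pool stands in for forward_batch.token_to_kv_pool, while set_kv_buffer and set_mla_kv_buffer are the calls shown in the diff.

    def write_kv_cache(pool, layer, cache_loc, k, v, k_rope=None):
        # Hypothetical helper mirroring the branch above: a separate rotary key
        # component goes through the MLA-specific buffer; otherwise the regular
        # KV buffer is used.
        if k_rope is None:
            pool.set_kv_buffer(layer, cache_loc, k, v)
        else:
            pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope)
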
@@ -1662,9 +1662,7 @@ class DeepseekV2AttentionMLA(nn.Module):
positions,
topk_indices,
):
# if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or \
(not forward_batch.forward_mode.is_decode() and self.current_attention_backend == 'dcu_mla'):
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
extra_args = {}
if self._fuse_rope_for_trtllm_mla(forward_batch):
extra_args = {
......
@@ -27,7 +27,10 @@ class DraftBackendFactory:
backend_type = self.server_args.attention_backend
if backend_type not in backend_map:
raise ValueError(error_template.format(backend_type=backend_type))
if backend_type != "dcu_mla":
raise ValueError(error_template.format(backend_type=backend_type))
else:
return backend_map["fa3"]()
return backend_map[backend_type]()
......
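
Because the flattened diff view drops indentation, the intent of the DraftBackendFactory change is easy to misread. Read as nested code, the new lookup keeps raising for unknown backends but lets dcu_mla, which has no draft attention backend of its own, fall back to the fa3 draft backend. A sketch under that reading, written as a hypothetical free function rather than the original factory method:

    def create_draft_attention_backend(backend_type, backend_map, error_template):
        # Unknown backends still raise; "dcu_mla" alone falls back to the "fa3"
        # draft attention backend instead of failing.
        if backend_type not in backend_map:
            if backend_type != "dcu_mla":
                raise ValueError(error_template.format(backend_type=backend_type))
            return backend_map["fa3"]()
        return backend_map[backend_type]()
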