Commit e36f865d authored by linhai1

Fix Bug: pass q_rope/k_rope through the DCU MLA backend, write the MLA KV cache via set_mla_kv_buffer when k_rope is supplied separately, and handle value head dims above the 256 flash-attn kernel limit.

parent 46da9556
@@ -387,18 +387,30 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if k_rope is None:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -432,7 +444,9 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
-        sinks=None,
+        q_rope = None,
+        k_rope = None,
+        sinks = None,
     ):
         if (
             forward_batch.forward_mode == ForwardMode.EXTEND
@@ -444,7 +458,7 @@ class DCUMLABackend(AttentionBackend):
             #     q, k, v, layer, forward_batch, save_kv_cache, sinks
             # )
             return self.flashattn_backend.forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, sinks
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
             )
         else:
             raise RuntimeError("skip prefill but use forward_extend")
@@ -453,7 +467,21 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                # forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if k_rope is None:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
...
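Note on the hunks above: forward_extend and forward_decode now accept q_rope/k_rope, and when k_rope arrives separately from the compressed k the cache write goes through set_mla_kv_buffer instead of set_kv_buffer. The sketch below only illustrates the intent of such a write, assuming a single latent buffer that holds [nope | rope] per token; the dimensions, buffer shape, and helper names are illustrative, not the pool's actual implementation.

import torch

# Assumed sizes, for illustration only.
KV_LORA_RANK = 512       # width of the compressed ("nope") latent
QK_ROPE_HEAD_DIM = 64    # width of the rotary ("rope") part

# One MLA cache slot per token: the concatenation [nope | rope].
kv_buffer = torch.zeros(1024, 1, KV_LORA_RANK + QK_ROPE_HEAD_DIM)

def set_kv_buffer_sketch(loc, k, v):
    # Fused path (k_rope is None): k already carries the full latent.
    kv_buffer[loc] = k

def set_mla_kv_buffer_sketch(loc, k_nope, k_rope):
    # Split path (k_rope given separately): write each part into its slice.
    kv_buffer[loc, :, :KV_LORA_RANK] = k_nope
    kv_buffer[loc, :, KV_LORA_RANK:] = k_rope

loc = torch.tensor([3, 7])
set_mla_kv_buffer_sketch(
    loc,
    torch.randn(2, 1, KV_LORA_RANK),
    torch.randn(2, 1, QK_ROPE_HEAD_DIM),
)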
@@ -6,6 +6,8 @@ from typing import Optional, Union
 
 import torch
 
+MAX_FLASH_ATTN_KERNEL_HEADDIM = 256
+
 def flash_attn_with_kvcache(
     q,
     k_cache,
@@ -40,7 +42,46 @@ def flash_attn_with_kvcache(
     sinks=None,
     ver=3,
 ):
-    return flash_attn_with_kvcache_interface(
+    if cu_seqlens_q is not None and q.shape[0] != cu_seqlens_q.shape[0] * max_seqlen_q:
+        v_cache = v_cache.view(-1, v_cache.shape[-2], v_cache.shape[-1])
+        if v_cache.shape[-1] > MAX_FLASH_ATTN_KERNEL_HEADDIM:
+            # Value head dim exceeds the kernel limit: run attention over the two
+            # halves of the value head dimension and concatenate the outputs.
+            out_1 = flash_attn_varlen_func_interface(
+                q=q,  # (total_q, num_heads, head_size_og)
+                k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]),  # (total_k, num_heads_k, head_size_og)
+                v=v_cache[:, :, :MAX_FLASH_ATTN_KERNEL_HEADDIM],  # first MAX_FLASH_ATTN_KERNEL_HEADDIM channels of the value head dim
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_q,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+            out_2 = flash_attn_varlen_func_interface(
+                q=q,  # (total_q, num_heads, head_size_og)
+                k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]),  # (total_k, num_heads_k, head_size_og)
+                v=v_cache[:, :, MAX_FLASH_ATTN_KERNEL_HEADDIM:],  # remaining channels of the value head dim
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_q,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+            return torch.cat([out_1, out_2], dim=-1)
+        else:
+            return flash_attn_varlen_func_interface(
+                q=q,  # (total_q, num_heads, head_size_og)
+                k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]),  # (total_k, num_heads_k, head_size_og)
+                v=v_cache.view(-1, v_cache.shape[-2], v_cache.shape[-1]),  # (total_k, num_heads_k, head_size_og)
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_q,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+    else:
+        return flash_attn_with_kvcache_interface(
         q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
         k_cache=k_cache,
         v_cache=v_cache,
...
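The wrapper change above routes ragged (varlen) batches to flash_attn_varlen_func_interface, and when the value head dimension exceeds MAX_FLASH_ATTN_KERNEL_HEADDIM (256) it runs the kernel twice, once per half of the value head dimension, and concatenates the results. The split is exact because the softmax weights depend only on q and k while the output is linear in v. A small check of that identity with plain reference PyTorch attention (not the flash-attn kernels; shapes are illustrative):

import torch

def ref_attn(q, k, v):
    # q, k: (seq_q/seq_k, heads, head_dim); v: (seq_k, heads, v_dim)
    scores = torch.einsum("qhd,khd->hqk", q, k) / q.shape[-1] ** 0.5
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)

torch.manual_seed(0)
q = torch.randn(4, 2, 64)
k = torch.randn(6, 2, 64)
v = torch.randn(6, 2, 512)   # value head dim larger than the 256 kernel limit
split = 256                  # plays the role of MAX_FLASH_ATTN_KERNEL_HEADDIM

full = ref_attn(q, k, v)
halves = torch.cat(
    [ref_attn(q, k, v[..., :split]), ref_attn(q, k, v[..., split:])], dim=-1
)
assert torch.allclose(full, halves, atol=1e-5)

The same argument holds under a causal mask, since masking only changes the softmax weights, which are shared by both calls.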
@@ -178,6 +178,7 @@ CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
     "flashmla",
     "cutlass_mla",
     "trtllm_mla",
+    "dcu_mla",
 ]
...
@@ -1662,7 +1662,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions,
         topk_indices,
     ):
-        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        # if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or \
+            (not forward_batch.forward_mode.is_decode() and self.current_attention_backend == 'dcu_mla'):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
                 extra_args = {
...
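The hunk above widens the dispatch in the absorbed-attention path: non-decode (prefill/extend) batches now take the FORWARD_ABSORB_CORE_ATTENTION_BACKENDS branch when the backend is "dcu_mla" as well, while decode keeps its previous path. A minimal sketch of just that predicate; the set contents below are placeholders, not the list defined in deepseek_v2.py:

# Placeholder contents, for illustration only.
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = {"trtllm_mla", "cutlass_mla"}

def takes_core_branch(backend: str, is_decode: bool) -> bool:
    # Mirrors the condition added in the hunk above.
    return backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or (
        not is_decode and backend == "dcu_mla"
    )

assert takes_core_branch("dcu_mla", is_decode=False)       # prefill/extend: new branch
assert not takes_core_branch("dcu_mla", is_decode=True)    # decode: unchanged path
assert takes_core_branch("trtllm_mla", is_decode=True)     # listed backends: unchanged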