"src/vscode:/vscode.git/clone" did not exist on "ed759f0aee721f8520c5bf94d4b7bd7c0ae3dcbb"
Commit e36f865d authored by linhai1

Fix Bug.

parent 46da9556
@@ -387,18 +387,30 @@ class DCUMLABackend(AttentionBackend):
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
):
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
if k_rope is None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
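The branch added above writes through `set_mla_kv_buffer` when a decoupled rope key is supplied, instead of the generic `set_kv_buffer`. As a rough mental model (a hedged sketch, not the repository's implementation; buffer layout, names, and dimensions are assumptions), an MLA token-to-KV pool stores the compressed latent and the rope key side by side in one per-token slot:

```python
# Hedged sketch only -- layout, names, and dims are assumed, not taken from the repo.
import torch

KV_LORA_RANK = 512       # assumed width of the compressed KV latent ("k" above)
QK_ROPE_HEAD_DIM = 64    # assumed width of the decoupled rope key ("k_rope" above)

def set_mla_kv_buffer_sketch(kv_buffer, cache_loc, k, k_rope):
    """Write the latent and rope parts of each token into one combined cache slot."""
    # kv_buffer: [num_slots, 1, KV_LORA_RANK + QK_ROPE_HEAD_DIM]
    # k:         [num_tokens, 1, KV_LORA_RANK]
    # k_rope:    [num_tokens, 1, QK_ROPE_HEAD_DIM]
    kv_buffer[cache_loc, :, :KV_LORA_RANK] = k
    kv_buffer[cache_loc, :, KV_LORA_RANK:] = k_rope

kv_buffer = torch.zeros(16, 1, KV_LORA_RANK + QK_ROPE_HEAD_DIM)
loc = torch.tensor([3, 7])
set_mla_kv_buffer_sketch(
    kv_buffer, loc,
    torch.randn(2, 1, KV_LORA_RANK),
    torch.randn(2, 1, QK_ROPE_HEAD_DIM),
)
```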
@@ -432,7 +444,9 @@ class DCUMLABackend(AttentionBackend):
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
sinks=None,
q_rope = None,
k_rope = None,
sinks = None,
):
if (
forward_batch.forward_mode == ForwardMode.EXTEND
@@ -444,7 +458,7 @@ class DCUMLABackend(AttentionBackend):
# q, k, v, layer, forward_batch, save_kv_cache, sinks
# )
return self.flashattn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, sinks
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
)
else:
raise RuntimeError("skip prefill but use forward_extend")
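Note on the delegation above: the arguments are forwarded positionally. Assuming the FlashAttention backend's `forward_extend` now orders its optional parameters as `(save_kv_cache, q_rope, k_rope, sinks)`, matching this backend's own signature (the delegate's signature is not shown in this diff), the previous call `q, k, v, layer, forward_batch, save_kv_cache, sinks` would have bound `sinks` to the slot `q_rope` now occupies. A sketch of that assumed signature:

```python
# Hedged sketch -- the delegate's real signature is not shown in this diff.
from typing import Optional
import torch

def forward_extend(          # assumed shape of flashattn_backend.forward_extend
    q, k, v, layer, forward_batch,
    save_kv_cache: bool = True,
    q_rope: Optional[torch.Tensor] = None,
    k_rope: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
):
    ...

# Positional forwarding only stays correct while this order matches; passing the
# optional tensors as keywords (q_rope=q_rope, k_rope=k_rope, sinks=sinks) would
# make the delegation immune to future reordering.
```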
@@ -453,7 +467,21 @@ class DCUMLABackend(AttentionBackend):
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
# forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if k_rope is None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
......
@@ -6,6 +6,8 @@ from typing import Optional, Union
import torch
MAX_FLASH_ATTN_KERNEL_HEADDIM = 256
def flash_attn_with_kvcache(
q,
k_cache,
@@ -40,7 +42,46 @@ def flash_attn_with_kvcache(
sinks=None,
ver=3,
):
return flash_attn_with_kvcache_interface(
if cu_seqlens_q is not None and q.shape[0] != cu_seqlens_q.shape[0] * max_seqlen_q:
v_cache = v_cache.view(-1, v_cache.shape[-2], v_cache.shape[-1])
if v_cache.shape[-1] > MAX_FLASH_ATTN_KERNEL_HEADDIM:
out_1 = flash_attn_varlen_func_interface(
q=q, # (total_q, num_heads, head_size_og)
k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]), # (total_k, num_heads_k, head_size_og)
v=v_cache[:, :, :MAX_FLASH_ATTN_KERNEL_HEADDIM], # (total_k, num_heads_k, head_size_og)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=softmax_scale,
causal=causal,
)
out_2 = flash_attn_varlen_func_interface(
q=q, # (total_q, num_heads, head_size_og)
k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]), # (total_k, num_heads_k, head_size_og)
v=v_cache[:, :, MAX_FLASH_ATTN_KERNEL_HEADDIM:], # (total_k, num_heads_k, head_size_og)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=softmax_scale,
causal=causal,
)
return torch.cat([out_1, out_2], dim=-1)
else:
return flash_attn_varlen_func_interface(
q=q, # (total_q, num_heads, head_size_og)
k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]), # (total_k, num_heads_k, head_size_og)
v=v_cache.view(-1, v_cache.shape[-2], v_cache.shape[-1]), # (total_k, num_heads_k, head_size_og)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=softmax_scale,
causal=causal,
)
else:
return flash_attn_with_kvcache_interface(
q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
k_cache=k_cache,
v_cache=v_cache,
......
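The split over `MAX_FLASH_ATTN_KERNEL_HEADDIM` above is exact, not an approximation: the attention output `softmax(QK^T)V` is linear in the head-dim columns of `V`, so running the kernel on two column slices of `v_cache` and concatenating the results reproduces the full output. A small self-contained check of that identity (plain PyTorch reference attention, not the kernel):

```python
# Hedged sanity check of the identity the split relies on -- not the kernel itself.
import torch
import torch.nn.functional as F

def ref_attn(q, k, v):
    # q: [s_q, d_k], k: [s_k, d_k], v: [s_k, d_v]
    scores = (q @ k.T) / (q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q = torch.randn(4, 576)   # assumed MLA-style head dims for illustration
k = torch.randn(9, 576)
v = torch.randn(9, 512)

split = 256  # kernel head-dim limit assumed in the patch
full = ref_attn(q, k, v)
halves = torch.cat([ref_attn(q, k, v[:, :split]),
                    ref_attn(q, k, v[:, split:])], dim=-1)
assert torch.allclose(full, halves, atol=1e-5)
```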
@@ -178,6 +178,7 @@ CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
"flashmla",
"cutlass_mla",
"trtllm_mla",
"dcu_mla",
]
......
@@ -1662,7 +1662,9 @@ class DeepseekV2AttentionMLA(nn.Module):
positions,
topk_indices,
):
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
# if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or \
(not forward_batch.forward_mode.is_decode() and self.current_attention_backend == 'dcu_mla'):
extra_args = {}
if self._fuse_rope_for_trtllm_mla(forward_batch):
extra_args = {
......
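The widened condition above routes `dcu_mla` through the absorbed-core path only for non-decode (prefill/extend) batches, while decode keeps the previous dispatch. A minimal sketch of the predicate, with `ForwardMode` and the backend list as simplified stand-ins (their real contents are assumed, only names visible in this diff are used):

```python
# Hedged sketch -- ForwardMode and the backend list are simplified stand-ins.
from enum import Enum, auto

class ForwardMode(Enum):
    EXTEND = auto()
    DECODE = auto()

    def is_decode(self) -> bool:
        return self is ForwardMode.DECODE

# Contents assumed/abridged; the real list lives in the model code.
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = {"flashmla", "cutlass_mla", "trtllm_mla"}

def use_absorbed_core(backend: str, mode: ForwardMode) -> bool:
    return backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or (
        not mode.is_decode() and backend == "dcu_mla"
    )

assert use_absorbed_core("dcu_mla", ForwardMode.EXTEND)       # prefill/extend: yes
assert not use_absorbed_core("dcu_mla", ForwardMode.DECODE)   # decode: unchanged path
assert use_absorbed_core("flashmla", ForwardMode.DECODE)      # core backends: always
```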