Unverified Commit c49c1d92 authored by fzyzcjy, committed by GitHub

Remove 200us slow concat kernel (part 2: srt) (#7020)

parent 0f1dfa1e
@@ -233,25 +233,49 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
-
-        bs = forward_batch.batch_size
-        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Reshape inputs
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
+        q_nope = q_nope.to(self.q_data_type)
+        q_rope = q_rope.to(self.q_data_type)
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
         o = cutlass_mla_decode(
-            q_nope_and_q_pe=reshape_q.to(self.q_data_type),
+            q_nope=q_nope,
+            q_pe=q_rope,
             kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
...
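Note (editor's sketch, not part of the commit): the fallback branch above recovers q_nope and q_rope by slicing the fused query; in PyTorch such slices are views into the same storage, so no copy kernel is launched inside the backend. A minimal standalone illustration, with placeholder shapes:

import torch

# Placeholder MLA-like dimensions (illustrative only).
tp_q_head_num, v_head_dim, rope_dim = 16, 512, 64
head_dim = v_head_dim + rope_dim

q = torch.randn(4, tp_q_head_num * head_dim)        # fused query from the caller
reshaped_q = q.view(-1, tp_q_head_num, head_dim)

q_nope = reshaped_q[:, :, :v_head_dim]              # view, shares q's storage
q_rope = reshaped_q[:, :, v_head_dim:]              # view, shares q's storage

assert q_nope.data_ptr() == q.data_ptr()                                  # no copy made
assert q_rope.data_ptr() == q.data_ptr() + v_head_dim * q.element_size()  # offset into same buffer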
@@ -1013,7 +1013,11 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_core(
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
-        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
+        if (
+            self.attention_backend == "fa3"
+            or self.attention_backend == "flashinfer"
+            or self.attention_backend == "cutlass_mla"
+        ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
...
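Note (editor's sketch, not part of the commit): with "cutlass_mla" added to this list, DeepseekV2AttentionMLA hands q_pe/k_pe to the backend through the q_rope/k_rope arguments instead of (presumably, in the path taken by backends not in this list) concatenating them into one fused query first; that per-layer concat is the roughly 200 us kernel the commit title refers to. A standalone toy illustration, with made-up shapes:

import torch

num_tokens, num_heads = 8, 16
nope_dim, rope_dim = 512, 64                        # illustrative MLA-style split

q_nope_out = torch.randn(num_tokens, num_heads, nope_dim)
q_pe = torch.randn(num_tokens, num_heads, rope_dim)

# Old path: build one fused query tensor before calling the backend.
# torch.cat allocates and fills new memory -- this is the extra kernel being removed.
q_fused = torch.cat([q_nope_out, q_pe], dim=-1)
assert q_fused.data_ptr() != q_nope_out.data_ptr()  # new buffer was materialized

# New path: backends that accept q_rope/k_rope get the two pieces as-is;
# the data they see is identical, just without the intermediate copy.
assert torch.equal(q_fused[..., :nope_dim], q_nope_out)
assert torch.equal(q_fused[..., nope_dim:], q_pe)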