attention_asm_mla.h 1.85 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>

// Declares the MLA decode stage-1 forward entry point. Only the prototype is
// visible in this header; the definition lives elsewhere.
//
// Tensor layouts, as annotated on the declaration:
//   Q                 : [num_seqs, num_heads, head_size]
//   KV                : [num_page, page_size, num_kv_heads, head_size]
//   qo_indptr         : [batch_size+1]
//   kv_indptr         : [batch_size+1]
//   kv_page_indices   : [num_page_used]
//   kv_last_page_lens : [batch_size]
//
// Scalars:
//   max_seqlen_q, softmax_scale
//   (presumably the longest query sequence length and the softmax scaling
//    factor -- TODO confirm against the implementation)
//
// Outputs (written through the trailing tensor references):
//   splitData         : [batch_size, num_kv_splits, num_heads, v_head_dim]
//   splitLse          : [batch_size, num_kv_splits, num_heads, 1]
void mla_decode_stage1_asm_fwd(
    torch::Tensor& Q,
    torch::Tensor& KV,
    torch::Tensor& qo_indptr,
    torch::Tensor& kv_indptr,
    torch::Tensor& kv_page_indices,
    torch::Tensor& kv_last_page_lens,
    int max_seqlen_q,
    float softmax_scale,
    // outputs
    torch::Tensor& splitData,
    torch::Tensor& splitLse);

// Declares the MLA prefill forward entry point. Only the prototype is visible
// in this header; the definition lives elsewhere.
//
// Tensor layouts, as annotated on the declaration:
//   Q                 : [num_seqs, num_heads, head_size]
//   KV                : [num_page, page_size, num_kv_heads,
//                        kv_lora_rank + qk_rope_head_dim]
//   qo_indptr         : [batch_size+1]
//   kv_indptr         : [batch_size+1]
//   kv_page_indices   : [num_page_used]
//   kv_last_page_lens : [batch_size]
//
// Scalars:
//   max_seqlen_q, softmax_scale
//   (presumably the longest query sequence length and the softmax scaling
//    factor -- TODO confirm against the implementation)
//
// Outputs (written through the trailing tensor references):
//   splitData         : [batch_size, num_kv_splits, num_heads, v_head_dim]
//   splitLse          : [batch_size, num_kv_splits, num_heads, 1]
// NOTE(review): the output shapes mention num_kv_splits even though this is
// the prefill path -- confirm these annotations against the implementation.
void mla_prefill_asm_fwd(
    torch::Tensor& Q,
    torch::Tensor& KV,
    torch::Tensor& qo_indptr,
    torch::Tensor& kv_indptr,
    torch::Tensor& kv_page_indices,
    torch::Tensor& kv_last_page_lens,
    int max_seqlen_q,
    float softmax_scale,
    // outputs
    torch::Tensor& splitData,
    torch::Tensor& splitLse);