// attention_ck.h (committed by Xiaowei.zhang)
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>

/// Declaration of pa_fwd_naive; the definition is not in this header.
/// NOTE(review): the "pa_" prefix and the block_tables / block_size
/// parameters suggest a paged-attention forward pass, and "naive" suggests an
/// unoptimized/reference path. Confirm against the implementation.
///
/// Tensor layouts are carried over from the original inline comments.
/// NOTE(review): the alternate K layout was written as
/// [num_batch, seqlen, num_kv_heads, head_size], but the alternate V layout
/// was written as [num_batch*seqlen, num_kv_heads, head_size]. Confirm which
/// form the implementation actually expects for each.
///
/// @param Q                 [num_seqs, num_heads, head_size]
/// @param K                 [num_blocks, num_kv_heads, head_size/x, block_size, x]
///                          or [num_batch, seqlen, num_kv_heads, head_size]
/// @param V                 [num_blocks, num_kv_heads, head_size, block_size]
///                          or [num_batch*seqlen, num_kv_heads, head_size]
/// @param block_tables      presumably maps sequences to KV-cache blocks (TODO confirm)
/// @param context_lens      presumably per-sequence context lengths (TODO confirm)
/// @param k_dequant_scales  presumably K dequantization scales, relevant when
///                          quant_algo selects a quantized path (TODO confirm)
/// @param v_dequant_scales  presumably V dequantization scales (TODO confirm)
/// @param max_seq_len       presumably the largest context length in the batch (TODO confirm)
/// @param num_kv_heads      number of KV heads (matches the K/V layouts above)
/// @param scale_s           presumably the softmax/score scale (TODO confirm)
/// @param scale_k           scalar scale for K; exact use not visible here
/// @param scale_v           scalar scale for V; exact use not visible here
/// @param block_size        block size (matches the paged K/V layouts above)
/// @param quant_algo        integer selector; meaning of each value not visible here
/// @param out_              optional tensor. The original comment groups every
///                          argument, this one included, under "above are
///                          input". Presumably an optional preallocated output
///                          buffer (TODO confirm).
/// @return a torch::Tensor, presumably the attention output (TODO confirm).
torch::Tensor pa_fwd_naive(
    torch::Tensor &Q,
    torch::Tensor &K,
    torch::Tensor &V,
    torch::Tensor &block_tables,
    torch::Tensor &context_lens,
    torch::Tensor &k_dequant_scales,
    torch::Tensor &v_dequant_scales,
    const int max_seq_len,
    const int num_kv_heads,
    const float scale_s,
    const float scale_k,
    const float scale_v,
    const int block_size,
    const int quant_algo,
    std::optional<torch::Tensor> &out_);