mha_batch_prefill.h 1.28 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>

namespace aiter::torch_itfs {

// Declaration of mha_batch_prefill; only the signature lives in this header,
// the definition is provided elsewhere.
//
// Shape annotations next to the tensor parameters are the ones recorded with
// the original declaration. Parameters without an annotation are not
// described further here.
// NOTE(review): their meaning can only be guessed from the names in this
// header -- confirm against the implementation before relying on it.
//
// Passing conventions visible in the signature:
//   * q is the only tensor taken by non-const reference.
//   * out_, bias_, alibi_slopes_ and gen_ are std::optional and are taken by
//     value.
//
// Returns a std::vector<at::Tensor>; the number and order of the returned
// tensors are not visible from this header (TODO confirm with the
// implementation).
std::vector<at::Tensor> mha_batch_prefill(
    at::Tensor& q,                                 // [total_q, hq, d]
    const at::Tensor& k,                           // [total_k, hk, d]
    const at::Tensor& v,                           // [total_k, hk, d]
    const at::Tensor& cu_seqlens_q,                // [b+1]
    const at::Tensor& kv_indptr,                   // [b+1]
    const at::Tensor& kv_page_indices,
    int max_seqlen_q,
    int max_seqlen_k,
    float p_dropout,
    float softmax_scale,
    float logits_soft_cap,
    bool zero_tensors,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    bool return_softmax_lse,
    bool return_dropout_randval,
    std::optional<at::Tensor> out_,                // [total_q, hq, d]
    std::optional<const at::Tensor> bias_,         // [total_q, max_seqlen_k]
    std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
    std::optional<at::Generator> gen_);

} // namespace aiter::torch_itfs