// mha_fwd.h
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>

namespace aiter::torch_itfs {

/// @brief Declares the torch-facing `mha_fwd` entry point (the name suggests a
///        multi-head-attention forward pass; the definition is not in this header).
///
/// Tensor layouts below are the ones annotated on the declaration itself.
///
/// @param q                      Query tensor, `[b, sq, hq, d]`. Taken by non-const
///                               reference, unlike `k` and `v`.
///                               NOTE(review): confirm whether the implementation
///                               actually modifies `q`.
/// @param k                      Key tensor, `[b, sk, hk, d]`.
/// @param v                      Value tensor, `[b, sk, hk, d]`.
/// @param p_dropout              Presumably the dropout probability - confirm range
///                               against the definition.
/// @param softmax_scale          Presumably the scale applied before softmax - confirm.
/// @param is_causal              Presumably enables causal masking - confirm.
/// @param window_size_left       Left window size; the meaning of special values is
///                               not visible from this header.
/// @param window_size_right      Right window size; same caveat as above.
/// @param return_softmax_lse     Presumably requests the softmax log-sum-exp as an
///                               extra output - confirm.
/// @param return_dropout_randval Presumably requests the dropout random values as
///                               an extra output - confirm.
/// @param out                    Optional output tensor, `[b, sq, hq, d]`.
/// @param bias                   Optional bias tensor, `[sq, sk]`.
/// @param alibi_slopes           Optional ALiBi slopes, `[hq]` or `[b, hq]`.
/// @param gen                    Optional generator; presumably the RNG source for
///                               dropout - confirm.
/// @return A vector of tensors. Its length and ordering are fixed by the
///         definition, which is not visible here.
std::vector<at::Tensor>
mha_fwd(at::Tensor& q,
        const at::Tensor& k,
        const at::Tensor& v,
        float p_dropout,
        float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        bool return_softmax_lse,
        bool return_dropout_randval,
        std::optional<at::Tensor> out,
        std::optional<const at::Tensor> bias,
        std::optional<const at::Tensor> alibi_slopes,
        std::optional<at::Generator> gen);

} // namespace aiter::torch_itfs