mha_bwd.h 1.59 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>

namespace aiter::torch_itfs {

/// @brief Multi-head-attention backward pass exposed to the torch front end.
///
/// Declaration only; the definition lives in another translation unit.
/// The signature (parameter types, order and names) must stay in sync with
/// that definition and with any Python binding that forwards to it.
///
/// Shape legend (taken from the per-argument shape annotations this header
/// has always carried): b = batch, sq = query length, sk = key length,
/// hq = query heads, hk = key/value heads, d = head dimension.
///
/// @param dout           [b, sq, hq, d] — presumably the upstream gradient of
///                       the forward output (confirm against the definition).
/// @param q              [b, sq, hq, d] — forward query.
/// @param k              [b, sk, hk, d] — forward key.
/// @param v              [b, sk, hk, d] — forward value.
/// @param out            [b, sq, hq, d] — forward output.
/// @param lse            [b, hq, sq] — presumably the softmax log-sum-exp
///                       saved by the forward pass (TODO confirm).
/// @param p_dropout      Dropout probability; assumed to match the forward
///                       call (TODO confirm).
/// @param softmax_scale  Softmax scaling factor (assumed to scale QK^T).
/// @param is_causal      Causal-mask flag.
/// @param window_size_left   Left sliding-window extent; NOTE(review): the
///                           meaning of negative values is not visible here.
/// @param window_size_right  Right sliding-window extent; same caveat.
/// @param deterministic  Presumably selects a deterministic reduction path.
/// @param dq             [b, sq, hq, d] — optional; presumably a
///                       caller-provided output buffer (TODO confirm).
/// @param dk             [b, sk, hk, d] — optional; same caveat as @p dq.
/// @param dv             [b, sk, hk, d] — optional; same caveat as @p dq.
/// @param dbias_         [sq, sk] — optional bias-gradient tensor.
/// @param bias_          [sq, sk] — optional attention bias.
/// @param alibi_slopes   [hq] or [b, hq] — optional ALiBi slopes.
/// @param rng_state      Optional tensor; presumably the dropout RNG state
///                       recorded during forward (TODO confirm).
/// @param gen            Optional random-number generator.
/// @return A vector of tensors whose order and contents are defined by the
///         implementation (presumably including dq/dk/dv — verify there).
std::vector<at::Tensor> mha_bwd(
    // Upstream gradient plus the tensors kept from the forward pass.
    const at::Tensor& dout,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const at::Tensor& out,
    const at::Tensor& lse,
    // Scalar attention configuration.
    float p_dropout,
    float softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    bool deterministic,
    // Optional gradient tensors.
    std::optional<at::Tensor> dq,
    std::optional<at::Tensor> dk,
    std::optional<at::Tensor> dv,
    std::optional<at::Tensor> dbias_,
    // Optional read-only attention inputs.
    std::optional<const at::Tensor> bias_,
    std::optional<const at::Tensor> alibi_slopes,
    // Optional dropout randomness sources.
    std::optional<const at::Tensor> rng_state,
    std::optional<at::Generator> gen);

} // namespace aiter::torch_itfs