st_attn.cpp 700 Bytes
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <torch/extension.h>
#include <ATen/ATen.h>

#include <vector>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <cuda_runtime.h>


#ifdef TK_COMPILE_ATTN
// Prototype of the attention entry point that the Python module below binds.
// Only the declaration lives in this file; the definition has to be supplied
// by another translation unit at link time. The whole declaration is compiled
// out unless TK_COMPILE_ATTN is defined.
//
// NOTE(review): the meaning of each argument is inferred from its name only
// (q/k/v/o look like query, key, value and output tensors; kernel_{t,w,h}_size
// look like per-axis window extents) — confirm against the definition.
extern torch::Tensor sta_forward(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    torch::Tensor o,
    int kernel_t_size,
    int kernel_w_size,
    int kernel_h_size,
    int text_length,
    bool process_text,
    bool has_text,
    int kernel_aspect_ratio_flag
);
#endif


// Python module definition. TORCH_EXTENSION_NAME is supplied by the build,
// so the importable module name is decided outside this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    // Module-level docstring (optional).
    mod.doc() = "Sliding Block Attention Kernels";

#ifdef TK_COMPILE_ATTN
    // Register sta_forward under the Python name "sta_fwd". The binding is
    // only present when TK_COMPILE_ATTN is defined; otherwise the module is
    // built with a docstring and no functions.
    mod.def(
        "sta_fwd",
        torch::wrap_pybind_function(sta_forward),
        "sliding tile attention, assuming tile size is (6,8,8)"
    );
#endif
}